diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs index abc06915c7..40d0404bcc 100644 --- a/backends/tfhe-cuda-backend/src/cuda_bind.rs +++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs @@ -61,7 +61,7 @@ extern "C" { pub fn cuda_drop_async(ptr: *mut c_void, v_stream: *const c_void) -> i32; /// Free memory for pointer `ptr` on GPU `gpu_index` synchronously - pub fn cuda_drop(ptr: *mut c_void) -> i32; + pub fn cuda_drop(ptr: *mut c_void, gpu_index: u32) -> i32; /// Get the maximum amount of shared memory on GPU `gpu_index` pub fn cuda_get_max_shared_memory(gpu_index: u32) -> i32; diff --git a/tfhe/src/core_crypto/gpu/algorithms/lwe_multi_bit_programmable_bootstrapping.rs b/tfhe/src/core_crypto/gpu/algorithms/lwe_multi_bit_programmable_bootstrapping.rs index ecc91ee87c..18d29d0f81 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/lwe_multi_bit_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/lwe_multi_bit_programmable_bootstrapping.rs @@ -120,4 +120,5 @@ pub fn cuda_multi_bit_programmable_bootstrap_lwe_ciphertext( stream, ); } + stream.synchronize(); } diff --git a/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs b/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs index 51929b9193..bcbe1d5c93 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/lwe_programmable_bootstrapping.rs @@ -78,4 +78,5 @@ pub fn cuda_programmable_bootstrap_lwe_ciphertext( stream, ); } + stream.synchronize(); } diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs index 484a4e4bbd..2ffff8ba71 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_keyswitch.rs @@ -90,10 +90,10 @@ fn lwe_encrypt_ks_decrypt_custom_mod>( .iter() .map(|&x| >::cast_into(x)) .collect_vec(); - let mut d_input_indexes = stream.malloc_async::(num_blocks as u32); - let mut d_output_indexes = stream.malloc_async::(num_blocks as u32); - stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes); - stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes); + let mut d_input_indexes = unsafe { stream.malloc_async::(num_blocks as u32) }; + let mut d_output_indexes = unsafe { stream.malloc_async::(num_blocks as u32) }; + unsafe { stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes) }; + unsafe { stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes) }; cuda_keyswitch_lwe_ciphertext( &d_ksk_big_to_small, diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_multi_bit_programmable_bootstrapping.rs b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_multi_bit_programmable_bootstrapping.rs index e2b20a3f9e..bbde73a3c1 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_multi_bit_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_multi_bit_programmable_bootstrapping.rs @@ -145,8 +145,8 @@ fn lwe_encrypt_multi_bit_pbs_decrypt_custom_mod< } let mut d_test_vector_indexes = - stream.malloc_async::(number_of_messages as u32); - stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes); + unsafe { stream.malloc_async::(number_of_messages as u32) }; + unsafe { stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes) }; let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0; let lwe_indexes_usize: Vec = (0..num_blocks).collect_vec(); @@ -154,10 +154,12 @@ fn lwe_encrypt_multi_bit_pbs_decrypt_custom_mod< .iter() .map(|&x| >::cast_into(x)) .collect_vec(); - let mut d_output_indexes = stream.malloc_async::(num_blocks as u32); - let mut d_input_indexes = stream.malloc_async::(num_blocks as u32); - stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes); - stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes); + let mut d_output_indexes = unsafe { stream.malloc_async::(num_blocks as u32) }; + let mut d_input_indexes = unsafe { stream.malloc_async::(num_blocks as u32) }; + unsafe { + stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes); + stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes); + } cuda_multi_bit_programmable_bootstrap_lwe_ciphertext( &d_lwe_ciphertext_in, diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs index 7e046d9909..acb48ade96 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping.rs @@ -127,8 +127,8 @@ fn lwe_encrypt_pbs_decrypt< } let mut d_test_vector_indexes = - stream.malloc_async::(number_of_messages as u32); - stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes); + unsafe { stream.malloc_async::(number_of_messages as u32) }; + unsafe { stream.copy_to_gpu_async(&mut d_test_vector_indexes, &test_vector_indexes) }; let num_blocks = d_lwe_ciphertext_in.0.lwe_ciphertext_count.0; let lwe_indexes_usize: Vec = (0..num_blocks).collect_vec(); @@ -136,10 +136,12 @@ fn lwe_encrypt_pbs_decrypt< .iter() .map(|&x| >::cast_into(x)) .collect_vec(); - let mut d_output_indexes = stream.malloc_async::(num_blocks as u32); - let mut d_input_indexes = stream.malloc_async::(num_blocks as u32); - stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes); - stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes); + let mut d_output_indexes = unsafe { stream.malloc_async::(num_blocks as u32) }; + let mut d_input_indexes = unsafe { stream.malloc_async::(num_blocks as u32) }; + unsafe { + stream.copy_to_gpu_async(&mut d_output_indexes, &lwe_indexes); + stream.copy_to_gpu_async(&mut d_input_indexes, &lwe_indexes); + } cuda_programmable_bootstrap_lwe_ciphertext( &d_lwe_ciphertext_in, diff --git a/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs index 7b68b67100..c60a4c5071 100644 --- a/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/glwe_ciphertext_list.rs @@ -18,11 +18,13 @@ impl CudaGlweCiphertextList { stream: &CudaStream, ) -> Self { // Allocate memory in the device - let d_vec = stream.malloc_async( - (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size) - * glwe_ciphertext_count.0) as u32, - ); - + let d_vec = unsafe { + stream.malloc_async( + (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size) + * glwe_ciphertext_count.0) as u32, + ) + }; + stream.synchronize(); let cuda_glwe_list = CudaGlweList { d_vec, glwe_ciphertext_count, @@ -43,13 +45,17 @@ impl CudaGlweCiphertextList { let polynomial_size = h_ct.polynomial_size(); let ciphertext_modulus = h_ct.ciphertext_modulus(); - let mut d_vec = stream.malloc_async( - (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size) - * glwe_ciphertext_count.0) as u32, - ); - + let mut d_vec = unsafe { + stream.malloc_async( + (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size) + * glwe_ciphertext_count.0) as u32, + ) + }; // Copy to the GPU - stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref()); + unsafe { + stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref()); + stream.synchronize(); + } let cuda_glwe_list = CudaGlweList { d_vec, @@ -70,8 +76,10 @@ impl CudaGlweCiphertextList { * glwe_ciphertext_size(self.0.glwe_dimension.to_glwe_size(), self.0.polynomial_size); let mut container: Vec = vec![T::ZERO; glwe_ct_size]; - stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec); - stream.synchronize(); + unsafe { + stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec); + stream.synchronize(); + } GlweCiphertextList::from_container( container, @@ -90,14 +98,20 @@ impl CudaGlweCiphertextList { let polynomial_size = h_ct.polynomial_size(); let ciphertext_modulus = h_ct.ciphertext_modulus(); - let mut d_vec = stream.malloc_async( - (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size) - * glwe_ciphertext_count.0) as u32, - ); + let mut d_vec = unsafe { + stream.malloc_async( + (glwe_ciphertext_size(glwe_dimension.to_glwe_size(), polynomial_size) + * glwe_ciphertext_count.0) as u32, + ) + }; // Copy to the GPU let h_input = h_ct.as_view().into_container(); - stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref()); + stream.synchronize(); + unsafe { + stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref()); + } + stream.synchronize(); let cuda_glwe_list = CudaGlweList { d_vec, diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_bootstrap_key.rs b/tfhe/src/core_crypto/gpu/entities/lwe_bootstrap_key.rs index 1dc9bf3cd9..87f7e80c5a 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_bootstrap_key.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_bootstrap_key.rs @@ -39,21 +39,25 @@ impl CudaLweBootstrapKey { let glwe_dimension = bsk.glwe_size().to_glwe_dimension(); // Allocate memory - let mut d_vec = stream.malloc_async::(lwe_bootstrap_key_size( - input_lwe_dimension, - glwe_dimension.to_glwe_size(), - polynomial_size, - decomp_level_count, - ) as u32); + let mut d_vec = unsafe { + stream.malloc_async::(lwe_bootstrap_key_size( + input_lwe_dimension, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_level_count, + ) as u32) + }; // Copy to the GPU - stream.convert_lwe_bootstrap_key_async( - &mut d_vec, - bsk.as_ref(), - input_lwe_dimension, - glwe_dimension, - decomp_level_count, - polynomial_size, - ); + unsafe { + stream.convert_lwe_bootstrap_key_async( + &mut d_vec, + bsk.as_ref(), + input_lwe_dimension, + glwe_dimension, + decomp_level_count, + polynomial_size, + ); + } stream.synchronize(); Self { d_vec, diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs index bd5e39d7a8..61991bc34f 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs @@ -18,8 +18,10 @@ impl CudaLweCiphertextList { stream: &CudaStream, ) -> Self { // Allocate memory in the device - let d_vec = - stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32); + let d_vec = unsafe { + stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32) + }; + stream.synchronize(); let cuda_lwe_list = CudaLweList { d_vec, @@ -41,10 +43,13 @@ impl CudaLweCiphertextList { // Copy to the GPU let h_input = h_ct.as_view().into_container(); - let mut d_vec = - stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32); - stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref()); - stream.synchronize(); + let mut d_vec = unsafe { + stream.malloc_async((lwe_dimension.to_lwe_size().0 * lwe_ciphertext_count.0) as u32) + }; + unsafe { + stream.copy_to_gpu_async(&mut d_vec, h_input.as_ref()); + stream.synchronize(); + } let cuda_lwe_list = CudaLweList { d_vec, lwe_ciphertext_count, @@ -73,8 +78,10 @@ impl CudaLweCiphertextList { let lwe_ct_size = self.0.lwe_ciphertext_count.0 * self.0.lwe_dimension.to_lwe_size().0; let mut container: Vec = vec![T::ZERO; lwe_ct_size]; - stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec); - stream.synchronize(); + unsafe { + stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec); + stream.synchronize(); + } LweCiphertextList::from_container( container, @@ -92,8 +99,11 @@ impl CudaLweCiphertextList { let ciphertext_modulus = h_ct.ciphertext_modulus(); // Copy to the GPU - let mut d_vec = stream.malloc_async((lwe_dimension.to_lwe_size().0) as u32); - stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref()); + let mut d_vec = unsafe { stream.malloc_async((lwe_dimension.to_lwe_size().0) as u32) }; + unsafe { + stream.copy_to_gpu_async(&mut d_vec, h_ct.as_ref()); + } + stream.synchronize(); let cuda_lwe_list = CudaLweList { d_vec, @@ -108,7 +118,9 @@ impl CudaLweCiphertextList { let lwe_ct_size = self.0.lwe_dimension.to_lwe_size().0; let mut container: Vec = vec![T::ZERO; lwe_ct_size]; - stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec); + unsafe { + stream.copy_to_cpu_async(container.as_mut_slice(), &self.0.d_vec); + } stream.synchronize(); LweCiphertext::from_container(container, self.ciphertext_modulus()) @@ -148,8 +160,11 @@ impl CudaLweCiphertextList { let ciphertext_modulus = self.ciphertext_modulus(); // Copy to the GPU - let mut d_vec = stream.malloc_async(self.0.d_vec.len() as u32); - stream.copy_gpu_to_gpu_async(&mut d_vec, &self.0.d_vec); + let mut d_vec = unsafe { stream.malloc_async(self.0.d_vec.len() as u32) }; + unsafe { + stream.copy_gpu_to_gpu_async(&mut d_vec, &self.0.d_vec); + } + stream.synchronize(); let cuda_lwe_list = CudaLweList { d_vec, diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs b/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs index d6bb2640c3..9f9407ed3a 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_keyswitch_key.rs @@ -28,15 +28,19 @@ impl CudaLweKeyswitchKey { let ciphertext_modulus = h_ksk.ciphertext_modulus(); // Allocate memory - let mut d_vec = stream.malloc_async::( - (input_lwe_size.to_lwe_dimension().0 - * lwe_keyswitch_key_input_key_element_encrypted_size( - decomp_level_count, - output_lwe_size, - )) as u32, - ); + let mut d_vec = unsafe { + stream.malloc_async::( + (input_lwe_size.to_lwe_dimension().0 + * lwe_keyswitch_key_input_key_element_encrypted_size( + decomp_level_count, + output_lwe_size, + )) as u32, + ) + }; - stream.convert_lwe_keyswitch_key_async(&mut d_vec, h_ksk.as_ref()); + unsafe { + stream.convert_lwe_keyswitch_key_async(&mut d_vec, h_ksk.as_ref()); + } stream.synchronize(); diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_multi_bit_bootstrap_key.rs b/tfhe/src/core_crypto/gpu/entities/lwe_multi_bit_bootstrap_key.rs index f23ed1db36..1ee7a20ce9 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_multi_bit_bootstrap_key.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_multi_bit_bootstrap_key.rs @@ -41,26 +41,30 @@ impl CudaLweMultiBitBootstrapKey { let grouping_factor = bsk.grouping_factor(); // Allocate memory - let mut d_vec = stream.malloc_async::( - lwe_multi_bit_bootstrap_key_size( + let mut d_vec = unsafe { + stream.malloc_async::( + lwe_multi_bit_bootstrap_key_size( + input_lwe_dimension, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_level_count, + grouping_factor, + ) + .unwrap() as u32, + ) + }; + // Copy to the GPU + unsafe { + stream.convert_lwe_multi_bit_bootstrap_key_async( + &mut d_vec, + bsk.as_ref(), input_lwe_dimension, - glwe_dimension.to_glwe_size(), - polynomial_size, + glwe_dimension, decomp_level_count, + polynomial_size, grouping_factor, - ) - .unwrap() as u32, - ); - // Copy to the GPU - stream.convert_lwe_multi_bit_bootstrap_key_async( - &mut d_vec, - bsk.as_ref(), - input_lwe_dimension, - glwe_dimension, - decomp_level_count, - polynomial_size, - grouping_factor, - ); + ); + } stream.synchronize(); Self { d_vec, diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index ce65cabd4e..9cec947a59 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -60,34 +60,42 @@ impl CudaStream { } /// Allocates `elements` on the GPU asynchronously - pub fn malloc_async(&self, elements: u32) -> CudaVec + /// + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after the copy + /// as soon as synchronization is required + pub unsafe fn malloc_async(&self, elements: u32) -> CudaVec where T: Numeric, { let size = elements as u64 * std::mem::size_of::() as u64; - unsafe { - let ptr = CudaPtr { - ptr: cuda_malloc_async(size, self.as_c_ptr()), - device: self.device(), - }; + let ptr = CudaPtr { + ptr: cuda_malloc_async(size, self.as_c_ptr()), + device: self.device(), + }; - CudaVec::new(ptr, elements as usize, self.device()) - } + CudaVec::new(ptr, elements as usize, self.device()) } - pub fn memset_async(&self, dest: &mut CudaVec, value: T) + /// Sets data on the GPU to a specific `value` + /// + /// # Safety + /// + /// - `dest` __must__ be a valid pointer to the GPU global memory + /// - [CudaStream::synchronize] __must__ be called after the copy + /// as soon as synchronization is required + pub unsafe fn memset_async(&self, dest: &mut CudaVec, value: T) where T: Numeric + Into, { let dest_size = dest.len() * std::mem::size_of::(); - unsafe { - cuda_memset_async( - dest.as_mut_c_ptr(), - value.into(), - dest_size as u64, - self.as_c_ptr(), - ); - } + cuda_memset_async( + dest.as_mut_c_ptr(), + value.into(), + dest_size as u64, + self.as_c_ptr(), + ); } /// Copies data from slice into GPU pointer @@ -95,23 +103,21 @@ impl CudaStream { /// # Safety /// /// - `dest` __must__ be a valid pointer to the GPU global memory - /// - [CudaDevice::cuda_synchronize_device] __must__ be called after the copy + /// - [CudaStream::synchronize] __must__ be called after the copy /// as soon as synchronization is required - pub fn copy_to_gpu_async(&self, dest: &mut CudaVec, src: &[T]) + pub unsafe fn copy_to_gpu_async(&self, dest: &mut CudaVec, src: &[T]) where T: Numeric, { let src_size = std::mem::size_of_val(src); assert!(dest.len() * std::mem::size_of::() >= src_size); - unsafe { - cuda_memcpy_async_to_gpu( - dest.as_mut_c_ptr(), - src.as_ptr().cast(), - src_size as u64, - self.as_c_ptr(), - ); - } + cuda_memcpy_async_to_gpu( + dest.as_mut_c_ptr(), + src.as_ptr().cast(), + src_size as u64, + self.as_c_ptr(), + ); } /// Copies data between different arrays in the GPU @@ -120,23 +126,20 @@ impl CudaStream { /// /// - `src` __must__ be a valid pointer to the GPU global memory /// - `dest` __must__ be a valid pointer to the GPU global memory - /// - [CudaDevice::cuda_synchronize_device] __must__ be called after the copy + /// - [CudaStream::synchronize] __must__ be called after the copy /// as soon as synchronization is required - pub fn copy_gpu_to_gpu_async(&self, dest: &mut CudaVec, src: &CudaVec) + pub unsafe fn copy_gpu_to_gpu_async(&self, dest: &mut CudaVec, src: &CudaVec) where T: Numeric, { assert!(dest.len() >= src.len()); let size = dest.len() * std::mem::size_of::(); - - unsafe { - cuda_memcpy_async_gpu_to_gpu( - dest.as_mut_c_ptr(), - src.as_c_ptr(), - size as u64, - self.as_c_ptr(), - ); - } + cuda_memcpy_async_gpu_to_gpu( + dest.as_mut_c_ptr(), + src.as_c_ptr(), + size as u64, + self.as_c_ptr(), + ); } /// Copies data from GPU pointer into slice @@ -144,28 +147,31 @@ impl CudaStream { /// # Safety /// /// - `src` __must__ be a valid pointer to the GPU global memory - /// - [CudaDevice::cuda_synchronize_device] __must__ be called as soon as synchronization is + /// - [CudaStream::synchronize] __must__ be called as soon as synchronization is /// required - pub fn copy_to_cpu_async(&self, dest: &mut [T], src: &CudaVec) + pub unsafe fn copy_to_cpu_async(&self, dest: &mut [T], src: &CudaVec) where T: Numeric, { let dest_size = std::mem::size_of_val(dest); assert!(dest_size >= src.len() * std::mem::size_of::()); - unsafe { - cuda_memcpy_async_to_cpu( - dest.as_mut_ptr().cast(), - src.as_c_ptr(), - dest_size as u64, - self.as_c_ptr(), - ); - } + cuda_memcpy_async_to_cpu( + dest.as_mut_ptr().cast(), + src.as_c_ptr(), + dest_size as u64, + self.as_c_ptr(), + ); } /// Discarding bootstrap on a vector of LWE ciphertexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required #[allow(clippy::too_many_arguments)] - pub fn bootstrap_low_latency_async( + pub unsafe fn bootstrap_low_latency_async( &self, lwe_array_out: &mut CudaVec, lwe_out_indexes: &CudaVec, @@ -183,44 +189,47 @@ impl CudaStream { lwe_idx: LweCiphertextIndex, ) { let mut pbs_buffer: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_bootstrap_low_latency_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(pbs_buffer), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - level.0 as u32, - num_samples, - self.device().get_max_shared_memory() as u32, - true, - ); - cuda_bootstrap_low_latency_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_out_indexes.as_c_ptr(), - test_vector.as_c_ptr(), - test_vector_indexes.as_c_ptr(), - lwe_array_in.as_c_ptr(), - lwe_in_indexes.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - pbs_buffer, - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - base_log.0 as u32, - level.0 as u32, - num_samples, - num_samples, - lwe_idx.0 as u32, - self.device().get_max_shared_memory() as u32, - ); - cleanup_cuda_bootstrap_low_latency(self.as_c_ptr(), std::ptr::addr_of_mut!(pbs_buffer)); - } + scratch_cuda_bootstrap_low_latency_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(pbs_buffer), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + level.0 as u32, + num_samples, + self.device().get_max_shared_memory() as u32, + true, + ); + cuda_bootstrap_low_latency_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_out_indexes.as_c_ptr(), + test_vector.as_c_ptr(), + test_vector_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + lwe_in_indexes.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + pbs_buffer, + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + base_log.0 as u32, + level.0 as u32, + num_samples, + num_samples, + lwe_idx.0 as u32, + self.device().get_max_shared_memory() as u32, + ); + cleanup_cuda_bootstrap_low_latency(self.as_c_ptr(), std::ptr::addr_of_mut!(pbs_buffer)); } /// Discarding bootstrap on a vector of LWE ciphertexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required #[allow(clippy::too_many_arguments)] - pub fn bootstrap_multi_bit_async( + pub unsafe fn bootstrap_multi_bit_async( &self, lwe_array_out: &mut CudaVec, output_indexes: &CudaVec, @@ -239,48 +248,52 @@ impl CudaStream { lwe_idx: LweCiphertextIndex, ) { let mut pbs_buffer: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_multi_bit_pbs_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(pbs_buffer), - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - level.0 as u32, - grouping_factor.0 as u32, - num_samples, - self.device().get_max_shared_memory() as u32, - true, - 0u32, - ); - cuda_multi_bit_pbs_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - output_indexes.as_c_ptr(), - test_vector.as_c_ptr(), - test_vector_indexes.as_c_ptr(), - lwe_array_in.as_c_ptr(), - input_indexes.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - pbs_buffer, - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - grouping_factor.0 as u32, - base_log.0 as u32, - level.0 as u32, - num_samples, - num_samples, - lwe_idx.0 as u32, - self.device().get_max_shared_memory() as u32, - 0u32, - ); - cleanup_cuda_multi_bit_pbs(self.as_c_ptr(), std::ptr::addr_of_mut!(pbs_buffer)); - } + scratch_cuda_multi_bit_pbs_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(pbs_buffer), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + level.0 as u32, + grouping_factor.0 as u32, + num_samples, + self.device().get_max_shared_memory() as u32, + true, + 0u32, + ); + cuda_multi_bit_pbs_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + output_indexes.as_c_ptr(), + test_vector.as_c_ptr(), + test_vector_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + input_indexes.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + pbs_buffer, + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + grouping_factor.0 as u32, + base_log.0 as u32, + level.0 as u32, + num_samples, + num_samples, + lwe_idx.0 as u32, + self.device().get_max_shared_memory() as u32, + 0u32, + ); + cleanup_cuda_multi_bit_pbs(self.as_c_ptr(), std::ptr::addr_of_mut!(pbs_buffer)); } + /// Discarding keyswitch on a vector of LWE ciphertexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required #[allow(clippy::too_many_arguments)] - pub fn keyswitch_async( + pub unsafe fn keyswitch_async( &self, lwe_array_out: &mut CudaVec, lwe_out_indexes: &CudaVec, @@ -293,26 +306,29 @@ impl CudaStream { l_gadget: DecompositionLevelCount, num_samples: u32, ) { - unsafe { - cuda_keyswitch_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_out_indexes.as_c_ptr(), - lwe_array_in.as_c_ptr(), - lwe_in_indexes.as_c_ptr(), - keyswitch_key.as_c_ptr(), - input_lwe_dimension.0 as u32, - output_lwe_dimension.0 as u32, - base_log.0 as u32, - l_gadget.0 as u32, - num_samples, - ); - } - } - - /// Convert bootstrap key + cuda_keyswitch_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_out_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + lwe_in_indexes.as_c_ptr(), + keyswitch_key.as_c_ptr(), + input_lwe_dimension.0 as u32, + output_lwe_dimension.0 as u32, + base_log.0 as u32, + l_gadget.0 as u32, + num_samples, + ); + } + + /// Convert keyswitch key + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required #[allow(clippy::too_many_arguments)] - pub fn convert_lwe_keyswitch_key_async( + pub unsafe fn convert_lwe_keyswitch_key_async( &self, dest: &mut CudaVec, src: &[T], @@ -321,8 +337,13 @@ impl CudaStream { } /// Convert bootstrap key + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required #[allow(clippy::too_many_arguments)] - pub fn convert_lwe_bootstrap_key_async( + pub unsafe fn convert_lwe_bootstrap_key_async( &self, dest: &mut CudaVec, src: &[T], @@ -334,22 +355,25 @@ impl CudaStream { let size = std::mem::size_of_val(src); assert_eq!(dest.len() * std::mem::size_of::(), size); - unsafe { - cuda_convert_lwe_bootstrap_key_64( - dest.as_mut_c_ptr(), - src.as_ptr().cast(), - self.as_c_ptr(), - input_lwe_dim.0 as u32, - glwe_dim.0 as u32, - l_gadget.0 as u32, - polynomial_size.0 as u32, - ); - }; + cuda_convert_lwe_bootstrap_key_64( + dest.as_mut_c_ptr(), + src.as_ptr().cast(), + self.as_c_ptr(), + input_lwe_dim.0 as u32, + glwe_dim.0 as u32, + l_gadget.0 as u32, + polynomial_size.0 as u32, + ); } /// Convert multi-bit bootstrap key + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required #[allow(clippy::too_many_arguments)] - pub fn convert_lwe_multi_bit_bootstrap_key_async( + pub unsafe fn convert_lwe_multi_bit_bootstrap_key_async( &self, dest: &mut CudaVec, src: &[T], @@ -361,23 +385,25 @@ impl CudaStream { ) { let size = std::mem::size_of_val(src); assert_eq!(dest.len() * std::mem::size_of::(), size); - - unsafe { - cuda_convert_lwe_multi_bit_bootstrap_key_64( - dest.as_mut_c_ptr(), - src.as_ptr().cast(), - self.as_c_ptr(), - input_lwe_dim.0 as u32, - glwe_dim.0 as u32, - l_gadget.0 as u32, - polynomial_size.0 as u32, - grouping_factor.0 as u32, - ) - }; + cuda_convert_lwe_multi_bit_bootstrap_key_64( + dest.as_mut_c_ptr(), + src.as_ptr().cast(), + self.as_c_ptr(), + input_lwe_dim.0 as u32, + glwe_dim.0 as u32, + l_gadget.0 as u32, + polynomial_size.0 as u32, + grouping_factor.0 as u32, + ) } /// Discarding addition of a vector of LWE ciphertexts - pub fn add_lwe_ciphertext_vector_async( + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn add_lwe_ciphertext_vector_async( &self, lwe_array_out: &mut CudaVec, lwe_array_in_1: &CudaVec, @@ -385,40 +411,46 @@ impl CudaStream { lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_add_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_in_1.as_c_ptr(), - lwe_array_in_2.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_add_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_in_1.as_c_ptr(), + lwe_array_in_2.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } - /// Discarding addition of a vector of LWE ciphertexts - pub fn add_lwe_ciphertext_vector_assign_async( + /// Discarding assigned addition of a vector of LWE ciphertexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn add_lwe_ciphertext_vector_assign_async( &self, lwe_array_out: &mut CudaVec, lwe_array_in: &CudaVec, lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_add_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_out.as_c_ptr(), - lwe_array_in.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_add_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_out.as_c_ptr(), + lwe_array_in.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } - /// Discarding addition of a vector of LWE ciphertexts - pub fn add_lwe_ciphertext_vector_plaintext_vector_async( + /// Discarding addition of a vector of LWE ciphertexts with a vector of plaintexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn add_lwe_ciphertext_vector_plaintext_vector_async( &self, lwe_array_out: &mut CudaVec, lwe_array_in: &CudaVec, @@ -426,77 +458,90 @@ impl CudaStream { lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_add_lwe_ciphertext_vector_plaintext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_in.as_c_ptr(), - plaintext_in.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_add_lwe_ciphertext_vector_plaintext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_in.as_c_ptr(), + plaintext_in.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } - /// Discarding addition of a vector of LWE ciphertexts - pub fn add_lwe_ciphertext_vector_plaintext_vector_assign_async( + /// Discarding assigned addition of a vector of LWE ciphertexts with a vector of plaintexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn add_lwe_ciphertext_vector_plaintext_vector_assign_async( &self, lwe_array_out: &mut CudaVec, plaintext_in: &CudaVec, lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_add_lwe_ciphertext_vector_plaintext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_out.as_c_ptr(), - plaintext_in.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_add_lwe_ciphertext_vector_plaintext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_out.as_c_ptr(), + plaintext_in.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } /// Discarding negation of a vector of LWE ciphertexts - pub fn negate_lwe_ciphertext_vector_async( + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn negate_lwe_ciphertext_vector_async( &self, lwe_array_out: &mut CudaVec, lwe_array_in: &CudaVec, lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_negate_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_in.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_negate_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_in.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } - /// Discarding negation of a vector of LWE ciphertexts - pub fn negate_lwe_ciphertext_vector_assign_async( + /// Discarding assigned negation of a vector of LWE ciphertexts + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn negate_lwe_ciphertext_vector_assign_async( &self, lwe_array_out: &mut CudaVec, lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_negate_lwe_ciphertext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_out.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_negate_lwe_ciphertext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_out.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } #[allow(clippy::too_many_arguments)] - pub fn negate_integer_radix_assign_async( + /// Discarding assign negation of a vector of LWE ciphertexts representing an integer + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn negate_integer_radix_assign_async( &self, lwe_array: &mut CudaVec, lwe_dimension: LweDimension, @@ -504,40 +549,46 @@ impl CudaStream { message_modulus: u32, carry_modulus: u32, ) { - unsafe { - cuda_negate_integer_radix_ciphertext_64_inplace( - self.as_c_ptr(), - lwe_array.as_mut_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - message_modulus, - carry_modulus, - ); - } + cuda_negate_integer_radix_ciphertext_64_inplace( + self.as_c_ptr(), + lwe_array.as_mut_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + message_modulus, + carry_modulus, + ); } - /// Discarding negation of a vector of LWE ciphertexts - pub fn mult_lwe_ciphertext_vector_cleartext_vector_assign_async( + /// Multiplication of a vector of LWEs with a vector of cleartexts (assigned) + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn mult_lwe_ciphertext_vector_cleartext_vector_assign_async( &self, lwe_array: &mut CudaVec, cleartext_array_in: &CudaVec, lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_mult_lwe_ciphertext_vector_cleartext_vector_64( - self.as_c_ptr(), - lwe_array.as_mut_c_ptr(), - lwe_array.as_c_ptr(), - cleartext_array_in.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_mult_lwe_ciphertext_vector_cleartext_vector_64( + self.as_c_ptr(), + lwe_array.as_mut_c_ptr(), + lwe_array.as_c_ptr(), + cleartext_array_in.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } - /// Discarding negation of a vector of LWE ciphertexts - pub fn mult_lwe_ciphertext_vector_cleartext_vector( + /// Multiplication of a vector of LWEs with a vector of cleartexts. + /// + /// # Safety + /// + /// [CudaStream::synchronize] __must__ be called as soon as synchronization is + /// required + pub unsafe fn mult_lwe_ciphertext_vector_cleartext_vector( &self, lwe_array_out: &mut CudaVec, lwe_array_in: &CudaVec, @@ -545,16 +596,14 @@ impl CudaStream { lwe_dimension: LweDimension, num_samples: u32, ) { - unsafe { - cuda_mult_lwe_ciphertext_vector_cleartext_vector_64( - self.as_c_ptr(), - lwe_array_out.as_mut_c_ptr(), - lwe_array_in.as_c_ptr(), - cleartext_array_in.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - ); - } + cuda_mult_lwe_ciphertext_vector_cleartext_vector_64( + self.as_c_ptr(), + lwe_array_out.as_mut_c_ptr(), + lwe_array_in.as_c_ptr(), + cleartext_array_in.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + ); } } @@ -586,10 +635,7 @@ impl Drop for CudaPtr { let device = self.device; device.synchronize_device(); - // Release memory asynchronously so control returns to the CPU asap - // let stream = CudaStream::new_unchecked(device); - // unsafe { cuda_drop_async(self.ptr, stream.as_c_ptr(), device.gpu_index()) }; - unsafe { cuda_drop(self.as_mut_c_ptr()) }; + unsafe { cuda_drop(self.as_mut_c_ptr(), device.gpu_index()) }; } } @@ -668,11 +714,13 @@ mod tests { let gpu_index: u32 = 0; let device = CudaDevice::new(gpu_index); let stream = CudaStream::new_unchecked(device); - let mut d_vec: CudaVec = stream.malloc_async::(vec.len() as u32); - stream.copy_to_gpu_async(&mut d_vec, &vec); - let mut empty = vec![0_u64; vec.len()]; - stream.copy_to_cpu_async(&mut empty, &d_vec); - stream.synchronize(); - assert_eq!(vec, empty); + unsafe { + let mut d_vec: CudaVec = stream.malloc_async::(vec.len() as u32); + stream.copy_to_gpu_async(&mut d_vec, &vec); + let mut empty = vec![0_u64; vec.len()]; + stream.copy_to_cpu_async(&mut empty, &d_vec); + stream.synchronize(); + assert_eq!(vec, empty); + } } } diff --git a/tfhe/src/integer/gpu/ciphertext/mod.rs b/tfhe/src/integer/gpu/ciphertext/mod.rs index d6def4e4fd..070ce830bd 100644 --- a/tfhe/src/integer/gpu/ciphertext/mod.rs +++ b/tfhe/src/integer/gpu/ciphertext/mod.rs @@ -468,10 +468,13 @@ impl CudaRadixCiphertext { .flat_map(|block| block.ct.clone().into_container()) .collect::>(); - stream.copy_to_gpu_async( - &mut self.d_blocks.0.d_vec, - h_radix_ciphertext.as_mut_slice(), - ); + unsafe { + stream.copy_to_gpu_async( + &mut self.d_blocks.0.d_vec, + h_radix_ciphertext.as_mut_slice(), + ); + } + stream.synchronize(); self.info = CudaRadixCiphertextInfo { blocks: radix @@ -592,8 +595,10 @@ impl CudaRadixCiphertext { let mut self_container: Vec = vec![0; self_size]; let mut other_container: Vec = vec![0; other_size]; - stream.copy_to_cpu_async(self_container.as_mut_slice(), &self.d_blocks.0.d_vec); - stream.copy_to_cpu_async(other_container.as_mut_slice(), &other.d_blocks.0.d_vec); + unsafe { + stream.copy_to_cpu_async(self_container.as_mut_slice(), &self.d_blocks.0.d_vec); + stream.copy_to_cpu_async(other_container.as_mut_slice(), &other.d_blocks.0.d_vec); + } stream.synchronize(); self_container == other_container diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index df501166e8..d2f5e8ebfb 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -144,7 +144,11 @@ where impl CudaStream { #[allow(clippy::too_many_arguments)] - pub fn scalar_addition_integer_radix_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn scalar_addition_integer_radix_assign_async( &self, lwe_array: &mut CudaVec, scalar_input: &CudaVec, @@ -153,59 +157,65 @@ impl CudaStream { message_modulus: u32, carry_modulus: u32, ) { - unsafe { - cuda_scalar_addition_integer_radix_ciphertext_64_inplace( - self.as_c_ptr(), - lwe_array.as_mut_c_ptr(), - scalar_input.as_c_ptr(), - lwe_dimension.0 as u32, - num_samples, - message_modulus, - carry_modulus, - ); - } + cuda_scalar_addition_integer_radix_ciphertext_64_inplace( + self.as_c_ptr(), + lwe_array.as_mut_c_ptr(), + scalar_input.as_c_ptr(), + lwe_dimension.0 as u32, + num_samples, + message_modulus, + carry_modulus, + ); } - pub fn small_scalar_mult_integer_radix_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn small_scalar_mult_integer_radix_assign_async( &self, lwe_array: &mut CudaVec, scalar: u64, lwe_dimension: LweDimension, num_blocks: u32, ) { - unsafe { - cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace( - self.as_c_ptr(), - lwe_array.as_mut_c_ptr(), - scalar, - lwe_dimension.0 as u32, - num_blocks, - ); - } + cuda_small_scalar_multiplication_integer_radix_ciphertext_64_inplace( + self.as_c_ptr(), + lwe_array.as_mut_c_ptr(), + scalar, + lwe_dimension.0 as u32, + num_blocks, + ); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_add_integer_radix_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_add_integer_radix_assign_async( &self, radix_lwe_left: &mut CudaVec, radix_lwe_right: &CudaVec, lwe_dimension: LweDimension, num_blocks: u32, ) { - unsafe { - cuda_add_lwe_ciphertext_vector_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - lwe_dimension.0 as u32, - num_blocks, - ); - } + cuda_add_lwe_ciphertext_vector_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + lwe_dimension.0 as u32, + num_blocks, + ); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_mul_integer_radix_classic_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_mul_integer_radix_classic_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_left: &CudaVec, @@ -224,53 +234,55 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - 0, - num_blocks, - PBSType::ClassicalLowLat as u32, - self.device().get_max_shared_memory() as u32, - true, - ); - cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - mem_ptr, - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - 0, - num_blocks, - PBSType::ClassicalLowLat as u32, - self.device().get_max_shared_memory() as u32, - ); - cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + 0, + num_blocks, + PBSType::ClassicalLowLat as u32, + self.device().get_max_shared_memory() as u32, + true, + ); + cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + mem_ptr, + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + 0, + num_blocks, + PBSType::ClassicalLowLat as u32, + self.device().get_max_shared_memory() as u32, + ); + cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_mul_integer_radix_multibit_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_mul_integer_radix_multibit_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_left: &CudaVec, @@ -290,53 +302,55 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - grouping_factor.0 as u32, - num_blocks, - PBSType::MultiBit as u32, - self.device().get_max_shared_memory() as u32, - true, - ); - cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - mem_ptr, - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - grouping_factor.0 as u32, - num_blocks, - PBSType::MultiBit as u32, - self.device().get_max_shared_memory() as u32, - ); - cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + grouping_factor.0 as u32, + num_blocks, + PBSType::MultiBit as u32, + self.device().get_max_shared_memory() as u32, + true, + ); + cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + mem_ptr, + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + grouping_factor.0 as u32, + num_blocks, + PBSType::MultiBit as u32, + self.device().get_max_shared_memory() as u32, + ); + cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_mul_integer_radix_classic_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_mul_integer_radix_classic_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, radix_lwe_right: &CudaVec, @@ -354,53 +368,55 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - 0, - num_blocks, - PBSType::ClassicalLowLat as u32, - self.device().get_max_shared_memory() as u32, - true, - ); - cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - mem_ptr, - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - 0, - num_blocks, - PBSType::ClassicalLowLat as u32, - self.device().get_max_shared_memory() as u32, - ); - cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + 0, + num_blocks, + PBSType::ClassicalLowLat as u32, + self.device().get_max_shared_memory() as u32, + true, + ); + cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + mem_ptr, + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + 0, + num_blocks, + PBSType::ClassicalLowLat as u32, + self.device().get_max_shared_memory() as u32, + ); + cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_mul_integer_radix_multibit_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_mul_integer_radix_multibit_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, radix_lwe_right: &CudaVec, @@ -419,53 +435,55 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - grouping_factor.0 as u32, - num_blocks, - PBSType::MultiBit as u32, - self.device().get_max_shared_memory() as u32, - true, - ); - cuda_integer_mult_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - mem_ptr, - message_modulus.0 as u32, - carry_modulus.0 as u32, - glwe_dimension.0 as u32, - lwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - grouping_factor.0 as u32, - num_blocks, - PBSType::MultiBit as u32, - self.device().get_max_shared_memory() as u32, - ); - cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + grouping_factor.0 as u32, + num_blocks, + PBSType::MultiBit as u32, + self.device().get_max_shared_memory() as u32, + true, + ); + cuda_integer_mult_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + mem_ptr, + message_modulus.0 as u32, + carry_modulus.0 as u32, + glwe_dimension.0 as u32, + lwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + grouping_factor.0 as u32, + num_blocks, + PBSType::MultiBit as u32, + self.device().get_max_shared_memory() as u32, + ); + cleanup_cuda_integer_mult(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_bitop_integer_radix_classic_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_bitop_integer_radix_classic_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_left: &CudaVec, @@ -486,42 +504,44 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - op as u32, - true, - ); - cuda_bitop_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + op as u32, + true, + ); + cuda_bitop_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_bitop_integer_radix_classic_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_bitop_integer_radix_classic_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, radix_lwe_right: &CudaVec, @@ -541,42 +561,44 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - op as u32, - true, - ); - cuda_bitop_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + op as u32, + true, + ); + cuda_bitop_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_bitnot_integer_radix_classic_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_bitnot_integer_radix_classic_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, bootstrapping_key: &CudaVec, @@ -594,41 +616,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - BitOpType::Not as u32, - true, - ); - cuda_bitnot_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + BitOpType::Not as u32, + true, + ); + cuda_bitnot_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_bitnot_integer_radix_multibit_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_bitnot_integer_radix_multibit_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, bootstrapping_key: &CudaVec, @@ -647,41 +671,45 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - BitOpType::Not as u32, - true, - ); - cuda_bitnot_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + BitOpType::Not as u32, + true, + ); + cuda_bitnot_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_bitop_integer_radix_multibit_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_bitop_integer_radix_multibit_kb_assign_async< + T: UnsignedInteger, + >( &self, radix_lwe: &mut CudaVec, clear_blocks: &CudaVec, @@ -702,44 +730,48 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - op as u32, - true, - ); - cuda_scalar_bitop_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe.as_mut_c_ptr(), - radix_lwe.as_mut_c_ptr(), - clear_blocks.as_c_ptr(), - clear_blocks.len() as u32, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - op as u32, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + op as u32, + true, + ); + cuda_scalar_bitop_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe.as_mut_c_ptr(), + radix_lwe.as_mut_c_ptr(), + clear_blocks.as_c_ptr(), + clear_blocks.len() as u32, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + op as u32, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_bitop_integer_radix_classic_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_bitop_integer_radix_classic_kb_assign_async< + T: UnsignedInteger, + >( &self, radix_lwe: &mut CudaVec, clear_blocks: &CudaVec, @@ -759,44 +791,46 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - op as u32, - true, - ); - cuda_scalar_bitop_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe.as_mut_c_ptr(), - radix_lwe.as_mut_c_ptr(), - clear_blocks.as_c_ptr(), - clear_blocks.len() as u32, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - op as u32, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + op as u32, + true, + ); + cuda_scalar_bitop_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe.as_mut_c_ptr(), + radix_lwe.as_mut_c_ptr(), + clear_blocks.as_c_ptr(), + clear_blocks.len() as u32, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + op as u32, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_bitop_integer_radix_multibit_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_bitop_integer_radix_multibit_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, radix_lwe_right: &CudaVec, @@ -817,42 +851,44 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_bitop_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - op as u32, - true, - ); - cuda_bitop_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_bitop_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + op as u32, + true, + ); + cuda_bitop_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_bitop(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_comparison_integer_radix_classic_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_comparison_integer_radix_classic_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_left: &CudaVec, @@ -873,44 +909,46 @@ impl CudaStream { op: ComparisonType, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_comparison_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - op as u32, - true, - ); + scratch_cuda_integer_radix_comparison_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + op as u32, + true, + ); - cuda_comparison_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); + cuda_comparison_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); - cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_comparison_integer_radix_classic_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_comparison_integer_radix_classic_kb_assign_async( &self, radix_lwe_left: &mut CudaVec, radix_lwe_right: &CudaVec, @@ -930,42 +968,44 @@ impl CudaStream { op: ComparisonType, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_comparison_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - op as u32, - true, - ); - cuda_comparison_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_comparison_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + op as u32, + true, + ); + cuda_comparison_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_comparison_integer_radix_multibit_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_comparison_integer_radix_multibit_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_left: &CudaVec, @@ -987,42 +1027,44 @@ impl CudaStream { op: ComparisonType, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_comparison_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - op as u32, - true, - ); - cuda_comparison_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_left.as_c_ptr(), - radix_lwe_right.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_comparison_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + op as u32, + true, + ); + cuda_comparison_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_left.as_c_ptr(), + radix_lwe_right.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_comparison_integer_radix_classic_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_comparison_integer_radix_classic_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_in: &CudaVec, @@ -1044,45 +1086,49 @@ impl CudaStream { op: ComparisonType, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_comparison_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - op as u32, - true, - ); + scratch_cuda_integer_radix_comparison_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + op as u32, + true, + ); - cuda_scalar_comparison_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_in.as_c_ptr(), - scalar_blocks.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - num_scalar_blocks, - ); + cuda_scalar_comparison_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_in.as_c_ptr(), + scalar_blocks.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + num_scalar_blocks, + ); - cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_comparison_integer_radix_multibit_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_comparison_integer_radix_multibit_kb_async< + T: UnsignedInteger, + >( &self, radix_lwe_out: &mut CudaVec, radix_lwe_in: &CudaVec, @@ -1105,43 +1151,45 @@ impl CudaStream { op: ComparisonType, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_comparison_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - op as u32, - true, - ); - cuda_scalar_comparison_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_in.as_c_ptr(), - scalar_blocks.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - num_scalar_blocks, - ); - cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_comparison_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + op as u32, + true, + ); + cuda_scalar_comparison_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_in.as_c_ptr(), + scalar_blocks.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + num_scalar_blocks, + ); + cleanup_cuda_integer_comparison(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn full_propagate_classic_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn full_propagate_classic_assign_async( &self, radix_lwe_input: &mut CudaVec, bootstrapping_key: &CudaVec, @@ -1158,43 +1206,45 @@ impl CudaStream { carry_modulus: CarryModulus, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_full_propagation_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_level.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - true, - ); - cuda_full_propagation_64_inplace( - self.as_c_ptr(), - radix_lwe_input.as_mut_c_ptr(), - mem_ptr, - keyswitch_key.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - 0, - num_blocks, - ); - cleanup_cuda_full_propagation(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_full_propagation_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_level.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + true, + ); + cuda_full_propagation_64_inplace( + self.as_c_ptr(), + radix_lwe_input.as_mut_c_ptr(), + mem_ptr, + keyswitch_key.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + 0, + num_blocks, + ); + cleanup_cuda_full_propagation(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn full_propagate_multibit_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn full_propagate_multibit_assign_async( &self, radix_lwe_input: &mut CudaVec, bootstrapping_key: &CudaVec, @@ -1212,43 +1262,45 @@ impl CudaStream { carry_modulus: CarryModulus, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_full_propagation_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - pbs_level.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - true, - ); - cuda_full_propagation_64_inplace( - self.as_c_ptr(), - radix_lwe_input.as_mut_c_ptr(), - mem_ptr, - keyswitch_key.as_c_ptr(), - bootstrapping_key.as_c_ptr(), - lwe_dimension.0 as u32, - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - ks_base_log.0 as u32, - ks_level.0 as u32, - pbs_base_log.0 as u32, - pbs_level.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - ); - cleanup_cuda_full_propagation(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_full_propagation_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + pbs_level.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + true, + ); + cuda_full_propagation_64_inplace( + self.as_c_ptr(), + radix_lwe_input.as_mut_c_ptr(), + mem_ptr, + keyswitch_key.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + ks_base_log.0 as u32, + ks_level.0 as u32, + pbs_base_log.0 as u32, + pbs_level.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + ); + cleanup_cuda_full_propagation(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn propagate_single_carry_classic_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn propagate_single_carry_classic_assign_async( &self, radix_lwe_input: &mut CudaVec, bootstrapping_key: &CudaVec, @@ -1265,43 +1317,45 @@ impl CudaStream { carry_modulus: CarryModulus, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - let big_lwe_dimension: u32 = glwe_dimension.0 as u32 * polynomial_size.0 as u32; - scratch_cuda_propagate_single_carry_low_latency_kb_64_inplace( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension, - lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - true, - ); - cuda_propagate_single_carry_low_latency_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_input.as_mut_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_propagate_single_carry_low_latency( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + let big_lwe_dimension: u32 = glwe_dimension.0 as u32 * polynomial_size.0 as u32; + scratch_cuda_propagate_single_carry_low_latency_kb_64_inplace( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + true, + ); + cuda_propagate_single_carry_low_latency_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_input.as_mut_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_propagate_single_carry_low_latency( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + ); } #[allow(clippy::too_many_arguments)] - pub fn propagate_single_carry_multibit_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn propagate_single_carry_multibit_assign_async( &self, radix_lwe_input: &mut CudaVec, bootstrapping_key: &CudaVec, @@ -1319,43 +1373,47 @@ impl CudaStream { carry_modulus: CarryModulus, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - let big_lwe_dimension: u32 = glwe_dimension.0 as u32 * polynomial_size.0 as u32; - scratch_cuda_propagate_single_carry_low_latency_kb_64_inplace( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension, - lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - true, - ); - cuda_propagate_single_carry_low_latency_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_input.as_mut_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_propagate_single_carry_low_latency( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + let big_lwe_dimension: u32 = glwe_dimension.0 as u32 * polynomial_size.0 as u32; + scratch_cuda_propagate_single_carry_low_latency_kb_64_inplace( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + true, + ); + cuda_propagate_single_carry_low_latency_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_input.as_mut_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_propagate_single_carry_low_latency( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + ); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_shift_left_integer_radix_classic_kb_assign_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_shift_left_integer_radix_classic_kb_assign_async< + T: UnsignedInteger, + >( &self, radix_lwe_left: &mut CudaVec, shift: u32, @@ -1374,44 +1432,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_shift_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - ShiftType::Left as u32, - true, - ); - cuda_integer_radix_scalar_shift_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - shift, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_shift( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_shift_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + ShiftType::Left as u32, + true, + ); + cuda_integer_radix_scalar_shift_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + shift, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_shift(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_shift_left_integer_radix_multibit_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_shift_left_integer_radix_multibit_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1433,44 +1490,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_shift_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - ShiftType::Left as u32, - true, - ); - cuda_integer_radix_scalar_shift_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - shift, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_shift( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_shift_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + ShiftType::Left as u32, + true, + ); + cuda_integer_radix_scalar_shift_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + shift, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_shift(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_shift_right_integer_radix_classic_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_shift_right_integer_radix_classic_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1491,44 +1547,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_shift_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - ShiftType::Right as u32, - true, - ); - cuda_integer_radix_scalar_shift_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - shift, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_shift( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_shift_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + ShiftType::Right as u32, + true, + ); + cuda_integer_radix_scalar_shift_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + shift, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_shift(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_shift_right_integer_radix_multibit_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_shift_right_integer_radix_multibit_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1550,44 +1605,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_shift_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - ShiftType::Right as u32, - true, - ); - cuda_integer_radix_scalar_shift_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - shift, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_shift( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_shift_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + ShiftType::Right as u32, + true, + ); + cuda_integer_radix_scalar_shift_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + shift, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_shift(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_cmux_integer_radix_classic_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_cmux_integer_radix_classic_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_condition: &CudaVec, @@ -1608,42 +1662,44 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_cmux_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - true, - ); - cuda_cmux_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_condition.as_c_ptr(), - radix_lwe_true.as_c_ptr(), - radix_lwe_false.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_cmux(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_cmux_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + true, + ); + cuda_cmux_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_condition.as_c_ptr(), + radix_lwe_true.as_c_ptr(), + radix_lwe_false.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_cmux(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_cmux_integer_radix_multibit_kb_async( + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_cmux_integer_radix_multibit_kb_async( &self, radix_lwe_out: &mut CudaVec, radix_lwe_condition: &CudaVec, @@ -1665,42 +1721,44 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_cmux_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - true, - ); - cuda_cmux_integer_radix_ciphertext_kb_64( - self.as_c_ptr(), - radix_lwe_out.as_mut_c_ptr(), - radix_lwe_condition.as_c_ptr(), - radix_lwe_true.as_c_ptr(), - radix_lwe_false.as_c_ptr(), - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_cmux(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); - } + scratch_cuda_integer_radix_cmux_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + true, + ); + cuda_cmux_integer_radix_ciphertext_kb_64( + self.as_c_ptr(), + radix_lwe_out.as_mut_c_ptr(), + radix_lwe_condition.as_c_ptr(), + radix_lwe_true.as_c_ptr(), + radix_lwe_false.as_c_ptr(), + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_cmux(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_rotate_left_integer_radix_classic_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_rotate_left_integer_radix_classic_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1721,44 +1779,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_rotate_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - ShiftType::Left as u32, - true, - ); - cuda_integer_radix_scalar_rotate_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - n, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_rotate( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_rotate_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + ShiftType::Left as u32, + true, + ); + cuda_integer_radix_scalar_rotate_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + n, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_rotate(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_rotate_left_integer_radix_multibit_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_rotate_left_integer_radix_multibit_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1780,44 +1837,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_rotate_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - ShiftType::Left as u32, - true, - ); - cuda_integer_radix_scalar_rotate_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - n, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_rotate( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_rotate_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + ShiftType::Left as u32, + true, + ); + cuda_integer_radix_scalar_rotate_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + n, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_rotate(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_rotate_right_integer_radix_classic_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_rotate_right_integer_radix_classic_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1838,44 +1894,43 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_rotate_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - 0, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::ClassicalLowLat as u32, - ShiftType::Right as u32, - true, - ); - cuda_integer_radix_scalar_rotate_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - n, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_rotate( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_rotate_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + 0, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::ClassicalLowLat as u32, + ShiftType::Right as u32, + true, + ); + cuda_integer_radix_scalar_rotate_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + n, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_rotate(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } #[allow(clippy::too_many_arguments)] - pub fn unchecked_scalar_rotate_right_integer_radix_multibit_kb_assign_async< + /// # Safety + /// + /// - [CudaStream::synchronize] __must__ be called after this function + /// as soon as synchronization is required + pub unsafe fn unchecked_scalar_rotate_right_integer_radix_multibit_kb_assign_async< T: UnsignedInteger, >( &self, @@ -1897,39 +1952,34 @@ impl CudaStream { num_blocks: u32, ) { let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - unsafe { - scratch_cuda_integer_radix_scalar_rotate_kb_64( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - pbs_grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - PBSType::MultiBit as u32, - ShiftType::Right as u32, - true, - ); - cuda_integer_radix_scalar_rotate_kb_64_inplace( - self.as_c_ptr(), - radix_lwe_left.as_mut_c_ptr(), - n, - mem_ptr, - bootstrapping_key.as_c_ptr(), - keyswitch_key.as_c_ptr(), - num_blocks, - ); - cleanup_cuda_integer_radix_scalar_rotate( - self.as_c_ptr(), - std::ptr::addr_of_mut!(mem_ptr), - ); - } + scratch_cuda_integer_radix_scalar_rotate_kb_64( + self.as_c_ptr(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + pbs_grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + PBSType::MultiBit as u32, + ShiftType::Right as u32, + true, + ); + cuda_integer_radix_scalar_rotate_kb_64_inplace( + self.as_c_ptr(), + radix_lwe_left.as_mut_c_ptr(), + n, + mem_ptr, + bootstrapping_key.as_c_ptr(), + keyswitch_key.as_c_ptr(), + num_blocks, + ); + cleanup_cuda_integer_radix_scalar_rotate(self.as_c_ptr(), std::ptr::addr_of_mut!(mem_ptr)); } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs index 2ad128c75c..cf88dad3ac 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_rotate.rs @@ -253,7 +253,6 @@ impl CudaServerKey { { let mut result = unsafe { ct.duplicate_async(stream) }; self.scalar_rotate_left_assign(&mut result, shift, stream); - stream.synchronize(); result } @@ -269,7 +268,6 @@ impl CudaServerKey { { let mut result = unsafe { ct.duplicate_async(stream) }; self.scalar_rotate_right_assign(&mut result, shift, stream); - stream.synchronize(); result } }