diff --git a/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs b/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs index 113e4f1616..e4f257bf2a 100644 --- a/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs +++ b/tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs @@ -220,4 +220,19 @@ impl CudaLweCiphertextList { pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus { self.0.ciphertext_modulus } + + /// # Safety + /// + /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must + /// not be dropped until stream is synchronised + pub unsafe fn set_to_zero_async(&mut self, streams: &CudaStreams) { + self.0.d_vec.memset_async(0u64, streams, 0); + } + + pub fn set_to_zero(&mut self, streams: &CudaStreams) { + unsafe { + self.set_to_zero_async(streams); + streams.synchronize_one(0); + } + } } diff --git a/tfhe/src/core_crypto/gpu/vec.rs b/tfhe/src/core_crypto/gpu/vec.rs index 8004cab6c2..ba6deff7a1 100644 --- a/tfhe/src/core_crypto/gpu/vec.rs +++ b/tfhe/src/core_crypto/gpu/vec.rs @@ -175,16 +175,13 @@ impl CudaVec { /// /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must /// not be dropped until streams is synchronised - pub unsafe fn memset_async(&mut self, value: T, streams: &CudaStreams, stream_index: u32) - where - T: Into, - { + pub unsafe fn memset_async(&mut self, value: u64, streams: &CudaStreams, stream_index: u32) { let size = self.len() * std::mem::size_of::(); // We check that self is not empty to avoid invalid pointers if size > 0 { cuda_memset_async( self.as_mut_c_ptr(stream_index), - value.into(), + value, size as u64, streams.ptr[stream_index as usize], streams.gpu_indexes[stream_index as usize].0,