Skip to content

Commit

Permalink
feat(gpu): add a function to set a CudaLweList to 0
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Mar 5, 2025
1 parent 75626ac commit 3c1687b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
15 changes: 15 additions & 0 deletions tfhe/src/core_crypto/gpu/entities/lwe_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,4 +220,19 @@ impl<T: UnsignedInteger> CudaLweCiphertextList<T> {
pub(crate) fn ciphertext_modulus(&self) -> CiphertextModulus<T> {
self.0.ciphertext_modulus
}

/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn set_to_zero_async(&mut self, streams: &CudaStreams) {
self.0.d_vec.memset_async(0u64, streams, 0);
}

pub fn set_to_zero(&mut self, streams: &CudaStreams) {
unsafe {
self.set_to_zero_async(streams);
streams.synchronize_one(0);
}
}
}
7 changes: 2 additions & 5 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,13 @@ impl<T: Numeric> CudaVec<T> {
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn memset_async(&mut self, value: T, streams: &CudaStreams, stream_index: u32)
where
T: Into<u64>,
{
pub unsafe fn memset_async(&mut self, value: u64, streams: &CudaStreams, stream_index: u32) {
let size = self.len() * std::mem::size_of::<T>();
// We check that self is not empty to avoid invalid pointers
if size > 0 {
cuda_memset_async(
self.as_mut_c_ptr(stream_index),
value.into(),
value,
size as u64,
streams.ptr[stream_index as usize],
streams.gpu_indexes[stream_index as usize].0,
Expand Down

0 comments on commit 3c1687b

Please sign in to comment.