Skip to content

Commit

Permalink
fix(gpu): return early in sum_ct if num radix is 2, pass different po…
Browse files Browse the repository at this point in the history
…inters to smart copy
  • Loading branch information
agnesLeroy committed Sep 12, 2024
1 parent 345f25c commit 9dca245
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
9 changes: 8 additions & 1 deletion backends/tfhe-cuda-backend/cuda/include/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,7 @@ template <typename Torus> struct int_overflowing_sub_memory {

template <typename Torus> struct int_sum_ciphertexts_vec_memory {
Torus *new_blocks;
Torus *new_blocks_copy;
Torus *old_blocks;
Torus *small_lwe_vector;
int_radix_params params;
Expand Down Expand Up @@ -1384,6 +1385,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
new_blocks = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
new_blocks_copy = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
old_blocks = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);
Expand Down Expand Up @@ -1415,6 +1419,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
this->new_blocks = new_blocks;
this->old_blocks = old_blocks;
this->small_lwe_vector = small_lwe_vector;
new_blocks_copy = (Torus *)cuda_malloc_async(
max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
streams[0], gpu_indexes[0]);

d_smart_copy_in = (int32_t *)cuda_malloc_async(
max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]);
Expand All @@ -1433,8 +1440,8 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
cuda_drop_async(small_lwe_vector, streams[0], gpu_indexes[0]);
}

cuda_drop_async(new_blocks_copy, streams[0], gpu_indexes[0]);
scp_mem->release(streams, gpu_indexes, gpu_count);

delete scp_mem;
}
};
Expand Down
12 changes: 11 additions & 1 deletion backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
int_radix_lut<Torus> *reused_lut) {

auto new_blocks = mem_ptr->new_blocks;
auto new_blocks_copy = mem_ptr->new_blocks_copy;
auto old_blocks = mem_ptr->old_blocks;
auto small_lwe_vector = mem_ptr->small_lwe_vector;

Expand Down Expand Up @@ -220,6 +221,12 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
big_lwe_size * sizeof(Torus),
streams[0], gpu_indexes[0]);
}
if (num_radix_in_vec == 2) {
host_addition<Torus>(streams[0], gpu_indexes[0], radix_lwe_out, old_blocks,
&old_blocks[num_blocks * big_lwe_size],
big_lwe_dimension, num_blocks);
return;
}

size_t r = num_radix_in_vec;
size_t total_modulus = message_modulus * carry_modulus;
Expand Down Expand Up @@ -310,8 +317,11 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
// inside d_smart_copy_in there are only -1 values
// it's fine to call smart_copy with same pointer
// as source and destination
cuda_memcpy_async_gpu_to_gpu(new_blocks_copy, new_blocks,
r * num_blocks * big_lwe_size * sizeof(Torus),
streams[0], gpu_indexes[0]);
smart_copy<Torus><<<sm_copy_count, 1024, 0, streams[0]>>>(
new_blocks, new_blocks, d_smart_copy_out, d_smart_copy_in,
new_blocks, new_blocks_copy, d_smart_copy_out, d_smart_copy_in,
big_lwe_size);
check_cuda_error(cudaGetLastError());

Expand Down

0 comments on commit 9dca245

Please sign in to comment.