Al/refactor mul and scalar mul to track degree and noise level #2114

Merged · 1 commit · Mar 3, 2025

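This PR reworks the CUDA integer mul and scalar-mul entry points so that radix ciphertexts cross the FFI boundary as CudaRadixCiphertextFFI descriptors rather than raw void */Torus * device pointers; the descriptor carries the block count, the LWE dimension, and (per the PR title) per-block degree and noise-level tracking. As a reading aid, a minimal sketch of such a descriptor follows. The ptr, num_radix_blocks and lwe_dimension fields are taken from usages visible in this diff; degrees and noise_levels are assumptions inferred from the PR title, and the authoritative definition lives in the tfhe-cuda-backend headers.

// Sketch only; not the verbatim tfhe-cuda-backend definition.
// `degrees` and `noise_levels` are assumed from the PR title.
struct CudaRadixCiphertextFFI {
  void *ptr;                  // device buffer holding all LWE blocks
  uint32_t num_radix_blocks;  // number of radix blocks behind `ptr`
  uint32_t lwe_dimension;     // LWE dimension shared by every block
  uint64_t *degrees;          // per-block degree (assumed)
  uint64_t *noise_levels;     // per-block noise level (assumed)
};
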
21 changes: 11 additions & 10 deletions backends/tfhe-cuda-backend/cuda/include/integer/integer.h
@@ -132,10 +132,11 @@ void scratch_cuda_integer_mult_radix_ciphertext_kb_64(

void cuda_integer_mult_radix_ciphertext_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
- void *radix_lwe_out, void const *radix_lwe_left, bool const is_bool_left,
- void const *radix_lwe_right, bool const is_bool_right, void *const *bsks,
- void *const *ksks, int8_t *mem_ptr, uint32_t polynomial_size,
- uint32_t num_blocks);
+ CudaRadixCiphertextFFI *radix_lwe_out,
+ CudaRadixCiphertextFFI const *radix_lwe_left, bool const is_bool_left,
+ CudaRadixCiphertextFFI const *radix_lwe_right, bool const is_bool_right,
+ void *const *bsks, void *const *ksks, int8_t *mem_ptr,
+ uint32_t polynomial_size, uint32_t num_blocks);

void cleanup_cuda_integer_mult(void *const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count,
@@ -375,9 +376,9 @@ void scratch_cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(

void cuda_integer_radix_partial_sum_ciphertexts_vec_kb_64(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
- void *radix_lwe_out, void *radix_lwe_vec, uint32_t num_radix_in_vec,
- int8_t *mem_ptr, void *const *bsks, void *const *ksks,
- uint32_t num_blocks_in_radix);
+ CudaRadixCiphertextFFI *radix_lwe_out,
+ CudaRadixCiphertextFFI *radix_lwe_vec, int8_t *mem_ptr, void *const *bsks,
+ void *const *ksks);

void cleanup_cuda_integer_radix_partial_sum_ciphertexts_vec(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
@@ -393,10 +394,10 @@ void scratch_cuda_integer_scalar_mul_kb_64(

void cuda_scalar_multiplication_integer_radix_ciphertext_64_inplace(
void *const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count,
- void *lwe_array, uint64_t const *decomposed_scalar,
+ CudaRadixCiphertextFFI *lwe_array, uint64_t const *decomposed_scalar,
uint64_t const *has_at_least_one_set, int8_t *mem_ptr, void *const *bsks,
- void *const *ksks, uint32_t lwe_dimension, uint32_t polynomial_size,
- uint32_t message_modulus, uint32_t num_blocks, uint32_t num_scalars);
+ void *const *ksks, uint32_t polynomial_size, uint32_t message_modulus,
+ uint32_t num_scalars);

void cleanup_cuda_integer_radix_scalar_mul(void *const *streams,
uint32_t const *gpu_indexes,
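For callers of integer.h, the change above is a pure signature migration: bare device pointers become descriptor pointers, and the partial-sum entry point drops its num_radix_in_vec / num_blocks_in_radix counts, presumably because the descriptors' own block metadata makes them redundant. A hedged before/after sketch of a mul call site; the d_*/ct_* locals are illustrative, not taken from the PR:

// Before: raw device pointers, nothing for the backend to validate.
cuda_integer_mult_radix_ciphertext_kb_64(
    streams, gpu_indexes, gpu_count, d_out, d_lhs, /*is_bool_left=*/false,
    d_rhs, /*is_bool_right=*/false, bsks, ksks, mem_ptr, polynomial_size,
    num_blocks);

// After: descriptors carry block count, LWE dimension, degrees and
// noise levels together with the device pointer.
cuda_integer_mult_radix_ciphertext_kb_64(
    streams, gpu_indexes, gpu_count, &ct_out, &ct_lhs, /*is_bool_left=*/false,
    &ct_rhs, /*is_bool_right=*/false, bsks, ksks, mem_ptr, polynomial_size,
    num_blocks);
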
151 changes: 71 additions & 80 deletions backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h
@@ -21,7 +21,7 @@ template <typename Torus>
__global__ void radix_blocks_rotate_right(Torus *dst, Torus *src,
uint32_t value, uint32_t blocks_count,
uint32_t lwe_size);
- void generate_ids_update_degrees(int *terms_degree, size_t *h_lwe_idx_in,
+ void generate_ids_update_degrees(uint64_t *terms_degree, size_t *h_lwe_idx_in,
size_t *h_lwe_idx_out,
int32_t *h_smart_copy_in,
int32_t *h_smart_copy_out, size_t ch_amount,
@@ -1161,10 +1161,10 @@ template <typename Torus> struct int_overflowing_sub_memory {
};

template <typename Torus> struct int_sum_ciphertexts_vec_memory {
- Torus *new_blocks;
- Torus *new_blocks_copy;
- Torus *old_blocks;
- Torus *small_lwe_vector;
+ CudaRadixCiphertextFFI *new_blocks;
+ CudaRadixCiphertextFFI *new_blocks_copy;
+ CudaRadixCiphertextFFI *old_blocks;
+ CudaRadixCiphertextFFI *small_lwe_vector;
int_radix_params params;

int32_t *d_smart_copy_in;
@@ -1183,34 +1183,22 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
int max_pbs_count = num_blocks_in_radix * max_num_radix_in_vec;

// allocate gpu memory for intermediate buffers
- new_blocks = (Torus *)cuda_malloc_async(
- max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
- streams[0], gpu_indexes[0]);
- new_blocks_copy = (Torus *)cuda_malloc_async(
- max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
- streams[0], gpu_indexes[0]);
- old_blocks = (Torus *)cuda_malloc_async(
- max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
- streams[0], gpu_indexes[0]);
- small_lwe_vector = (Torus *)cuda_malloc_async(
- max_pbs_count * (params.small_lwe_dimension + 1) * sizeof(Torus),
- streams[0], gpu_indexes[0]);
- cuda_memset_async(new_blocks, 0,
- max_pbs_count * (params.big_lwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
- cuda_memset_async(new_blocks_copy, 0,
- max_pbs_count * (params.big_lwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
- cuda_memset_async(old_blocks, 0,
- max_pbs_count * (params.big_lwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
- cuda_memset_async(small_lwe_vector, 0,
- max_pbs_count * (params.small_lwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
+ new_blocks = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
+ new_blocks, max_pbs_count,
+ params.big_lwe_dimension);
+ new_blocks_copy = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
+ new_blocks_copy, max_pbs_count,
+ params.big_lwe_dimension);
+ old_blocks = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
+ old_blocks, max_pbs_count,
+ params.big_lwe_dimension);
+ small_lwe_vector = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
+ small_lwe_vector, max_pbs_count,
+ params.small_lwe_dimension);

d_smart_copy_in = (int32_t *)cuda_malloc_async(
max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]);
@@ -1227,8 +1215,9 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
uint32_t gpu_count, int_radix_params params,
uint32_t num_blocks_in_radix,
uint32_t max_num_radix_in_vec,
- Torus *new_blocks, Torus *old_blocks,
- Torus *small_lwe_vector) {
+ CudaRadixCiphertextFFI *new_blocks,
+ CudaRadixCiphertextFFI *old_blocks,
+ CudaRadixCiphertextFFI *small_lwe_vector) {
mem_reuse = true;
this->params = params;

@@ -1238,13 +1227,10 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
this->new_blocks = new_blocks;
this->old_blocks = old_blocks;
this->small_lwe_vector = small_lwe_vector;
- new_blocks_copy = (Torus *)cuda_malloc_async(
- max_pbs_count * (params.big_lwe_dimension + 1) * sizeof(Torus),
- streams[0], gpu_indexes[0]);
- cuda_memset_async(new_blocks_copy, 0,
- max_pbs_count * (params.big_lwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
+ new_blocks_copy = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0],
+ new_blocks_copy, max_pbs_count,
+ params.big_lwe_dimension);

d_smart_copy_in = (int32_t *)cuda_malloc_async(
max_pbs_count * sizeof(int32_t), streams[0], gpu_indexes[0]);
@@ -1262,12 +1248,15 @@ template <typename Torus> struct int_sum_ciphertexts_vec_memory {
cuda_drop_async(d_smart_copy_out, streams[0], gpu_indexes[0]);

if (!mem_reuse) {
- cuda_drop_async(new_blocks, streams[0], gpu_indexes[0]);
- cuda_drop_async(old_blocks, streams[0], gpu_indexes[0]);
- cuda_drop_async(small_lwe_vector, streams[0], gpu_indexes[0]);
+ release_radix_ciphertext(streams[0], gpu_indexes[0], new_blocks);
+ delete new_blocks;
+ release_radix_ciphertext(streams[0], gpu_indexes[0], old_blocks);
+ delete old_blocks;
+ release_radix_ciphertext(streams[0], gpu_indexes[0], small_lwe_vector);
+ delete small_lwe_vector;
}

- cuda_drop_async(new_blocks_copy, streams[0], gpu_indexes[0]);
+ release_radix_ciphertext(streams[0], gpu_indexes[0], new_blocks_copy);
+ delete new_blocks_copy;
}
};
// For sequential algorithm in group propagation
@@ -2482,7 +2471,7 @@ template <typename Torus> struct int_zero_out_if_buffer {

int_radix_params params;

- Torus *tmp;
+ CudaRadixCiphertextFFI *tmp;

cudaStream_t *true_streams;
cudaStream_t *false_streams;
@@ -2495,10 +2484,11 @@
this->params = params;
active_gpu_count = get_active_gpu_count(num_radix_blocks, gpu_count);

- Torus big_size =
- (params.big_lwe_dimension + 1) * num_radix_blocks * sizeof(Torus);
if (allocate_gpu_memory) {
- tmp = (Torus *)cuda_malloc_async(big_size, streams[0], gpu_indexes[0]);
+ tmp = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0], tmp,
+ num_radix_blocks,
+ params.big_lwe_dimension);
// We may use a different stream to allow concurrent operation
true_streams =
(cudaStream_t *)malloc(active_gpu_count * sizeof(cudaStream_t));
@@ -2512,7 +2502,8 @@ }
}
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count) {
- cuda_drop_async(tmp, streams[0], gpu_indexes[0]);
+ release_radix_ciphertext(streams[0], gpu_indexes[0], tmp);
+ delete tmp;
for (uint j = 0; j < active_gpu_count; j++) {
cuda_destroy_stream(true_streams[j], gpu_indexes[j]);
cuda_destroy_stream(false_streams[j], gpu_indexes[j]);
@@ -2523,9 +2514,9 @@
};

template <typename Torus> struct int_mul_memory {
- Torus *vector_result_sb;
- Torus *block_mul_res;
- Torus *small_lwe_vector;
+ CudaRadixCiphertextFFI *vector_result_sb;
+ CudaRadixCiphertextFFI *block_mul_res;
+ CudaRadixCiphertextFFI *small_lwe_vector;

int_radix_lut<Torus> *luts_array; // lsb msb
int_radix_lut<Torus> *zero_out_predicate_lut;
@@ -2574,7 +2565,6 @@ template <typename Torus> struct int_mul_memory {
auto polynomial_size = params.polynomial_size;
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
- auto lwe_dimension = params.small_lwe_dimension;

// 'vector_result_lsb' contains blocks from all possible shifts of
// radix_lwe_left excluding zero ciphertext blocks
@@ -2587,17 +2577,18 @@ template <typename Torus> struct int_mul_memory {
int total_block_count = lsb_vector_block_count + msb_vector_block_count;

// allocate memory for intermediate buffers
- vector_result_sb = (Torus *)cuda_malloc_async(
- 2 * total_block_count * (polynomial_size * glwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
- block_mul_res = (Torus *)cuda_malloc_async(
- 2 * total_block_count * (polynomial_size * glwe_dimension + 1) *
- sizeof(Torus),
- streams[0], gpu_indexes[0]);
- small_lwe_vector = (Torus *)cuda_malloc_async(
- total_block_count * (lwe_dimension + 1) * sizeof(Torus), streams[0],
- gpu_indexes[0]);
+ vector_result_sb = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(
+ streams[0], gpu_indexes[0], vector_result_sb, 2 * total_block_count,
+ params.big_lwe_dimension);
+ block_mul_res = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(
+ streams[0], gpu_indexes[0], block_mul_res, 2 * total_block_count,
+ params.big_lwe_dimension);
+ small_lwe_vector = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(
+ streams[0], gpu_indexes[0], small_lwe_vector, total_block_count,
+ params.small_lwe_dimension);

// create int_radix_lut objects for lsb, msb, message, carry
// luts_array -> lut = {lsb_acc, msb_acc}
@@ -2658,9 +2649,12 @@ template <typename Torus> struct int_mul_memory {

return;
}
- cuda_drop_async(vector_result_sb, streams[0], gpu_indexes[0]);
- cuda_drop_async(block_mul_res, streams[0], gpu_indexes[0]);
- cuda_drop_async(small_lwe_vector, streams[0], gpu_indexes[0]);
+ release_radix_ciphertext(streams[0], gpu_indexes[0], vector_result_sb);
+ delete vector_result_sb;
+ release_radix_ciphertext(streams[0], gpu_indexes[0], block_mul_res);
+ delete block_mul_res;
+ release_radix_ciphertext(streams[0], gpu_indexes[0], small_lwe_vector);
+ delete small_lwe_vector;

luts_array->release(streams, gpu_indexes, gpu_count);
sum_ciphertexts_mem->release(streams, gpu_indexes, gpu_count);
@@ -4435,7 +4429,7 @@ template <typename Torus> struct int_scalar_mul_buffer {
int_radix_params params;
int_logical_scalar_shift_buffer<Torus> *logical_scalar_shift_buffer;
int_sum_ciphertexts_vec_memory<Torus> *sum_ciphertexts_vec_mem;
- Torus *preshifted_buffer;
+ CudaRadixCiphertextFFI *preshifted_buffer;
CudaRadixCiphertextFFI *all_shifted_buffer;
int_sc_prop_memory<Torus> *sc_prop_mem;
bool anticipated_buffers_drop;
@@ -4450,25 +4444,21 @@

if (allocate_gpu_memory) {
uint32_t msg_bits = (uint32_t)std::log2(params.message_modulus);
- uint32_t lwe_size = params.big_lwe_dimension + 1;
- uint32_t lwe_size_bytes = lwe_size * sizeof(Torus);
size_t num_ciphertext_bits = msg_bits * num_radix_blocks;

//// Contains all shifted values of lhs for shift in range (0..msg_bits)
//// The idea is that with these we can create all other shift that are
/// in / range (0..total_bits) for free (block rotation)
- preshifted_buffer = (Torus *)cuda_malloc_async(
- num_ciphertext_bits * lwe_size_bytes, streams[0], gpu_indexes[0]);
+ preshifted_buffer = new CudaRadixCiphertextFFI;
+ create_zero_radix_ciphertext_async<Torus>(
+ streams[0], gpu_indexes[0], preshifted_buffer, num_ciphertext_bits,
+ params.big_lwe_dimension);

all_shifted_buffer = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(
streams[0], gpu_indexes[0], all_shifted_buffer,
num_ciphertext_bits * num_radix_blocks, params.big_lwe_dimension);

- cuda_memset_async(preshifted_buffer, 0,
- num_ciphertext_bits * lwe_size_bytes, streams[0],
- gpu_indexes[0]);

if (num_ciphertext_bits * num_radix_blocks >= num_radix_blocks + 2)
logical_scalar_shift_buffer =
new int_logical_scalar_shift_buffer<Torus>(
@@ -4500,7 +4490,8 @@ template <typename Torus> struct int_scalar_mul_buffer {
release_radix_ciphertext(streams[0], gpu_indexes[0], all_shifted_buffer);
delete all_shifted_buffer;
if (!anticipated_buffers_drop) {
- cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
+ release_radix_ciphertext(streams[0], gpu_indexes[0], preshifted_buffer);
+ delete preshifted_buffer;
logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);
delete (logical_scalar_shift_buffer);
}
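Every buffer struct touched in this header follows the same two idioms, distilled below from the hunks above (a sketch; names and block counts vary per buffer). Allocation replaces a cuda_malloc_async + cuda_memset_async pair, and release replaces a single cuda_drop_async:

// Allocate: heap-allocate the descriptor, then create a zeroed radix
// ciphertext of `num_blocks` blocks of dimension `big_lwe_dimension`.
auto *buf = new CudaRadixCiphertextFFI;
create_zero_radix_ciphertext_async<Torus>(streams[0], gpu_indexes[0], buf,
                                          num_blocks,
                                          params.big_lwe_dimension);

// Release: free the device memory tracked by the descriptor, then the
// descriptor itself; each `new` above is paired with exactly this.
release_radix_ciphertext(streams[0], gpu_indexes[0], buf);
delete buf;

The explicit cuda_memset_async calls vanish because create_zero_radix_ciphertext_async already hands back zeroed blocks (and, given this PR's goal, plausibly initializes the tracked degree and noise-level metadata as well).
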
20 changes: 14 additions & 6 deletions backends/tfhe-cuda-backend/cuda/src/integer/cmux.cuh
@@ -7,11 +7,20 @@
template <typename Torus>
__host__ void zero_out_if(cudaStream_t const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count,
- Torus *lwe_array_out, Torus const *lwe_array_input,
- Torus const *lwe_condition,
+ CudaRadixCiphertextFFI *lwe_array_out,
+ CudaRadixCiphertextFFI const *lwe_array_input,
+ CudaRadixCiphertextFFI const *lwe_condition,
int_zero_out_if_buffer<Torus> *mem_ptr,
int_radix_lut<Torus> *predicate, void *const *bsks,
Torus *const *ksks, uint32_t num_radix_blocks) {
+ if (lwe_array_out->num_radix_blocks < num_radix_blocks ||
+ lwe_array_input->num_radix_blocks < num_radix_blocks)
+ PANIC("Cuda error: input or output radix ciphertexts does not have enough "
+ "blocks")
+ if (lwe_array_out->lwe_dimension != lwe_array_input->lwe_dimension ||
+ lwe_array_input->lwe_dimension != lwe_condition->lwe_dimension)
+ PANIC("Cuda error: input and output radix ciphertexts must have the same "
+ "lwe dimension")
cuda_set_device(gpu_indexes[0]);
auto params = mem_ptr->params;

@@ -21,12 +30,11 @@ __host__ void zero_out_if(cudaStream_t const *streams,
host_pack_bivariate_blocks_with_single_block<Torus>(
streams, gpu_indexes, gpu_count, tmp_lwe_array_input,
predicate->lwe_indexes_in, lwe_array_input, lwe_condition,
- predicate->lwe_indexes_in, params.big_lwe_dimension,
- params.message_modulus, num_radix_blocks);
+ predicate->lwe_indexes_in, params.message_modulus, num_radix_blocks);

- legacy_integer_radix_apply_univariate_lookup_table_kb<Torus>(
+ integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, lwe_array_out, tmp_lwe_array_input, bsks,
- ksks, num_radix_blocks, predicate);
+ ksks, predicate, num_radix_blocks);
}

template <typename Torus>
19 changes: 15 additions & 4 deletions backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -557,18 +557,29 @@ __global__ void device_pack_bivariate_blocks_with_single_block(
template <typename Torus>
__host__ void host_pack_bivariate_blocks_with_single_block(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
- uint32_t gpu_count, Torus *lwe_array_out, Torus const *lwe_indexes_out,
- Torus const *lwe_array_1, Torus const *lwe_2, Torus const *lwe_indexes_in,
- uint32_t lwe_dimension, uint32_t shift, uint32_t num_radix_blocks) {
+ uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array_out,
+ Torus const *lwe_indexes_out, CudaRadixCiphertextFFI const *lwe_array_1,
+ CudaRadixCiphertextFFI const *lwe_2, Torus const *lwe_indexes_in,
+ uint32_t shift, uint32_t num_radix_blocks) {

+ if (lwe_array_out->num_radix_blocks < num_radix_blocks ||
+ lwe_array_1->num_radix_blocks < num_radix_blocks)
+ PANIC("Cuda error: input or output radix ciphertexts does not have enough "
+ "blocks")
+ if (lwe_array_out->lwe_dimension != lwe_array_1->lwe_dimension ||
+ lwe_array_1->lwe_dimension != lwe_2->lwe_dimension)
+ PANIC("Cuda error: input and output radix ciphertexts must have the same "
+ "lwe dimension")
+ auto lwe_dimension = lwe_array_out->lwe_dimension;
cuda_set_device(gpu_indexes[0]);
// Left message is shifted
int num_blocks = 0, num_threads = 0;
int num_entries = num_radix_blocks * (lwe_dimension + 1);
getNumBlocksAndThreads(num_entries, 512, num_blocks, num_threads);
device_pack_bivariate_blocks_with_single_block<Torus>
<<<num_blocks, num_threads, 0, streams[0]>>>(
- lwe_array_out, lwe_indexes_out, lwe_array_1, lwe_2, lwe_indexes_in,
+ (Torus *)lwe_array_out->ptr, lwe_indexes_out,
+ (Torus *)lwe_array_1->ptr, (Torus *)lwe_2->ptr, lwe_indexes_in,
lwe_dimension, shift, num_radix_blocks);
check_cuda_error(cudaGetLastError());
}
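One detail worth noting in the rewritten host_pack_bivariate_blocks_with_single_block: the launch math is unchanged, only the source of lwe_dimension moved into the output descriptor. The kernel covers num_entries = num_radix_blocks * (lwe_dimension + 1) Torus elements with blocks of at most 512 threads. Below is a sketch of the ceil-division sizing such a helper typically performs; the real getNumBlocksAndThreads lives in the backend utilities and may differ:

#include <algorithm>

// Hypothetical stand-in for getNumBlocksAndThreads: cap the block size,
// then cover num_entries with ceil(num_entries / num_threads) blocks.
// E.g. 4 radix blocks with lwe_dimension = 2047 give
// num_entries = 4 * 2048 = 8192, i.e. 16 blocks of 512 threads.
static void get_num_blocks_and_threads_sketch(int num_entries, int max_threads,
                                              int &num_blocks,
                                              int &num_threads) {
  num_threads = std::min(max_threads, num_entries);
  num_blocks = (num_entries + num_threads - 1) / num_threads;
}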