Skip to content

Commit

Permalink
chore(gpu): update gf4 parameters and make them default in benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Feb 27, 2025
1 parent 8565b79 commit 8d16460
Show file tree
Hide file tree
Showing 36 changed files with 276 additions and 267 deletions.
2 changes: 1 addition & 1 deletion tfhe/benches/high_level_api/erc20.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ fn main() {
#[cfg(not(feature = "gpu"))]
let params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
#[cfg(feature = "gpu")]
let params = PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
let params = PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS;

let config = ConfigBuilder::with_custom_parameters(params).build();
let cks = ClientKey::generate(config);
Expand Down
4 changes: 2 additions & 2 deletions tfhe/benches/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub mod shortint_utils {
use tfhe::shortint::parameters::current_params::V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
use tfhe::shortint::parameters::list_compression::CompressionParameters;
#[cfg(feature = "gpu")]
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS;
use tfhe::shortint::parameters::{
ShortintKeySwitchingParameters, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
};
Expand All @@ -63,7 +63,7 @@ pub mod shortint_utils {

if env_config.is_multi_bit {
#[cfg(feature = "gpu")]
let params = vec![PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS.into()];
let params = vec![PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS.into()];
#[cfg(not(feature = "gpu"))]
let params = vec![
V1_0_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64.into(),
Expand Down
20 changes: 10 additions & 10 deletions tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS;
use crate::shortint::{ClassicPBSParameters, PBSParameters};
use crate::{set_server_key, ClientKey, ConfigBuilder};

Expand Down Expand Up @@ -34,7 +34,7 @@ fn test_uint8_quickstart_gpu() {

#[test]
fn test_uint8_quickstart_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_uint8_quickstart(&client_key);
}

Expand All @@ -46,7 +46,7 @@ fn test_uint32_quickstart_gpu() {

#[test]
fn test_uint32_quickstart_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_uint32_quickstart(&client_key);
}

Expand All @@ -58,7 +58,7 @@ fn test_uint64_quickstart_gpu() {

#[test]
fn test_uint64_quickstart_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_uint64_quickstart(&client_key);
}

Expand All @@ -82,7 +82,7 @@ fn test_uint32_bitwise_gpu() {

#[test]
fn test_uint32_bitwise_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_uint32_bitwise(&client_key);
}

Expand All @@ -94,7 +94,7 @@ fn test_if_then_else_gpu() {

#[test]
fn test_if_then_else_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_if_then_else(&client_key);
}

Expand All @@ -106,7 +106,7 @@ fn test_sum_gpu() {

#[test]
fn test_sum_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_sum(&client_key);
}

Expand All @@ -118,7 +118,7 @@ fn test_is_even_is_odd_gpu() {

#[test]
fn test_is_even_is_odd_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_is_even_is_odd(&client_key);
}

Expand All @@ -130,7 +130,7 @@ fn test_leading_trailing_zeros_ones_gpu() {

#[test]
fn test_leading_trailing_zeros_ones_gpu_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_leading_trailing_zeros_ones(&client_key);
}

Expand All @@ -142,6 +142,6 @@ fn test_ilog2_gpu() {

#[test]
fn test_ilog2_multibit() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS));
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_4_KS_PBS));
super::test_case_ilog2(&client_key);
}
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/keys/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl Default for IntegerConfig {
crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into();
#[cfg(feature = "gpu")]
let params =
crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
.into();
Self {
block_parameters: params,
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ mod tests {
use crate::shortint::parameters::{
// TODO GPU DRIFT UPDATE
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use crate::shortint::ShortintParameterSet;
Expand Down Expand Up @@ -704,7 +704,7 @@ mod tests {
for params in [
// TODO GPU DRIFT UPDATE
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64.into(),
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64.into(),
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64.into(),
] {
let (radix_cks, sks) =
gen_keys_radix_gpu::<ShortintParameterSet>(params, NUM_BLOCKS, &streams);
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,14 @@ where
/// use tfhe::core_crypto::gpu::vec::GpuIndex;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
/// // generate the client key and the server key:
/// let num_blocks = 4;
/// # // TODO GPU DRIFT UPDATE
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// ```
pub fn gen_keys_radix_gpu<P>(
parameters_set: P,
Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/integer/gpu/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::CudaServerKey;
/// use tfhe::integer::ClientKey;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// # // TODO GPU DRIFT UPDATE
/// // Generate the client key:
/// let cks = ClientKey::new(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let cks = ClientKey::new(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// // Generate the server key:
/// let sks = CudaServerKey::new(&cks, &streams);
Expand Down Expand Up @@ -172,13 +172,13 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::CudaServerKey;
/// use tfhe::integer::{ClientKey, CompressedServerKey, ServerKey};
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
/// let size = 4;
/// # // TODO GPU DRIFT UPDATE
/// let cks = ClientKey::new(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let cks = ClientKey::new(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let compressed_sks = CompressedServerKey::new_radix_compressed_server_key(&cks);
/// let cuda_sks = CudaServerKey::decompress_from_cpu(&compressed_sks, &streams);
/// let cpu_sks = compressed_sks.decompress();
Expand Down
4 changes: 2 additions & 2 deletions tfhe/src/integer/gpu/server_key/radix/abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// # // TODO GPU DRIFT UPDATE
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
///
/// let msg = -14i32;
///
Expand Down
16 changes: 8 additions & 8 deletions tfhe/src/integer/gpu/server_key/radix/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// # // TODO GPU DRIFT UPDATE
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
///
/// let msg1 = 14;
/// let msg2 = 97;
Expand Down Expand Up @@ -139,15 +139,15 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// # // TODO GPU DRIFT UPDATE
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
///
/// let msg1 = 10;
/// let msg2 = 127;
Expand Down Expand Up @@ -423,15 +423,15 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// # // TODO GPU DRIFT UPDATE
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let total_bits = num_blocks * cks.parameters().message_modulus().0.ilog2() as usize;
/// let modulus = 1 << total_bits;
///
Expand Down Expand Up @@ -601,15 +601,15 @@ impl CudaServerKey {
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
/// # // TODO GPU DRIFT UPDATE
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let gpu_index = 0;
/// let streams = CudaStreams::new_single_gpu(GpuIndex::new(gpu_index));
///
/// # // TODO GPU DRIFT UPDATE
/// // Generate the client key and the server key:
/// let num_blocks = 4;
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, num_blocks, &streams);
/// let total_bits = num_blocks * cks.parameters().message_modulus().0.ilog2() as usize;
/// let modulus = 1 << total_bits;
///
Expand Down
Loading

0 comments on commit 8d16460

Please sign in to comment.