Skip to content

Commit

Permalink
refactor(hlapi): split long files of hlapi
Browse files Browse the repository at this point in the history
This splits the long base.rs files into multiple ones,
to make it easier to navigate.

There is no code changes appart from moving stuff.
  • Loading branch information
tmontaigu committed Feb 1, 2024
1 parent bce3bf1 commit 48f67fb
Show file tree
Hide file tree
Showing 17 changed files with 7,008 additions and 6,904 deletions.
300 changes: 18 additions & 282 deletions tfhe/src/high_level_api/booleans/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,180 +2,19 @@ use std::borrow::Borrow;
use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};

use crate::high_level_api::booleans::compressed::CompressedFheBool;
use crate::high_level_api::details::MaybeCloned;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_stream;
use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId};
use crate::high_level_api::keys::{ClientKey, InternalServerKey, PublicKey};
use crate::high_level_api::traits::{
FheDecrypt, FheEq, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse,
};
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::{FheEq, IfThenElse};
use crate::integer::BooleanBlock;
use crate::named::Named;
use crate::shortint::ciphertext::{Degree, NotTrivialCiphertextError};
use crate::{CompactFheBool, CompactPublicKey, CompressedPublicKey, Device};
use serde::{Deserialize, Deserializer, Serialize, Serializer};

pub(in crate::high_level_api) enum InnerBoolean {
Cpu(BooleanBlock),
#[cfg(feature = "gpu")]
Cuda(crate::integer::gpu::ciphertext::CudaRadixCiphertext),
}
use crate::shortint::ciphertext::NotTrivialCiphertextError;
use crate::{CompactFheBool, Device};
use serde::{Deserialize, Serialize};

impl Clone for InnerBoolean {
fn clone(&self) -> Self {
match self {
Self::Cpu(inner) => Self::Cpu(inner.clone()),
#[cfg(feature = "gpu")]
Self::Cuda(inner) => {
with_thread_local_cuda_stream(|stream| Self::Cuda(inner.duplicate(stream)))
}
}
}
}
impl serde::Serialize for InnerBoolean {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
Self::Cpu(cpu_ct) => cpu_ct.serialize(serializer),
#[cfg(feature = "gpu")]
Self::Cuda(_) => self.on_cpu().serialize(serializer),
}
}
}

impl<'de> serde::Deserialize<'de> for InnerBoolean {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let mut deserialized = Self::Cpu(crate::integer::BooleanBlock::deserialize(deserializer)?);
deserialized.move_to_device_of_server_key_if_set();
Ok(deserialized)
}
}

impl From<BooleanBlock> for InnerBoolean {
fn from(value: BooleanBlock) -> Self {
Self::Cpu(value)
}
}

#[cfg(feature = "gpu")]
impl From<crate::integer::gpu::ciphertext::CudaRadixCiphertext> for InnerBoolean {
fn from(value: crate::integer::gpu::ciphertext::CudaRadixCiphertext) -> Self {
Self::Cuda(value)
}
}

impl InnerBoolean {
pub(crate) fn current_device(&self) -> Device {
match self {
Self::Cpu(_) => Device::Cpu,
#[cfg(feature = "gpu")]
Self::Cuda(_) => Device::CudaGpu,
}
}

/// Returns the inner cpu ciphertext if self is on the CPU, otherwise, returns a copy
/// that is on the CPU
pub(crate) fn on_cpu(&self) -> MaybeCloned<'_, BooleanBlock> {
match self {
Self::Cpu(ct) => MaybeCloned::Borrowed(ct),
#[cfg(feature = "gpu")]
Self::Cuda(ct) => with_thread_local_cuda_stream(|stream| {
let cpu_ct = ct.to_radix_ciphertext(stream);
MaybeCloned::Cloned(BooleanBlock::new_unchecked(cpu_ct.blocks[0].clone()))
}),
}
}

/// Returns the inner cpu ciphertext if self is on the CPU, otherwise, returns a copy
/// that is on the CPU
#[cfg(feature = "gpu")]
pub(crate) fn on_gpu(
&self,
) -> MaybeCloned<'_, crate::integer::gpu::ciphertext::CudaRadixCiphertext> {
match self {
Self::Cpu(ct) => with_thread_local_cuda_stream(|stream| {
let ct_as_radix = crate::integer::RadixCiphertext::from(vec![ct.0.clone()]);
let cuda_ct =
crate::integer::gpu::ciphertext::CudaRadixCiphertext::from_radix_ciphertext(
&ct_as_radix,
stream,
);
MaybeCloned::Cloned(cuda_ct)
}),
#[cfg(feature = "gpu")]
Self::Cuda(ct) => MaybeCloned::Borrowed(ct),
}
}

pub(crate) fn as_cpu_mut(&mut self) -> &mut BooleanBlock {
match self {
Self::Cpu(block) => block,
#[cfg(feature = "gpu")]
_ => {
self.move_to_device(Device::Cpu);
self.as_cpu_mut()
}
}
}

#[cfg(feature = "gpu")]
#[track_caller]
pub(crate) fn as_gpu_mut(
&mut self,
) -> &mut crate::integer::gpu::ciphertext::CudaRadixCiphertext {
if let Self::Cuda(radix_ct) = self {
radix_ct
} else {
self.move_to_device(Device::CudaGpu);
self.as_gpu_mut()
}
}

pub(crate) fn move_to_device(&mut self, device: Device) {
match (&self, device) {
(Self::Cpu(_), Device::Cpu) => {
// Nothing to do, we already are on the correct device
}
#[cfg(feature = "gpu")]
(Self::Cuda(_), Device::CudaGpu) => {
// Nothing to do, we already are on the correct device
}
#[cfg(feature = "gpu")]
(Self::Cpu(ct), Device::CudaGpu) => {
let ct_as_radix = crate::integer::RadixCiphertext::from(vec![ct.0.clone()]);
let new_inner = with_thread_local_cuda_stream(|stream| {
crate::integer::gpu::ciphertext::CudaRadixCiphertext::from_radix_ciphertext(
&ct_as_radix,
stream,
)
});
*self = Self::Cuda(new_inner);
}
#[cfg(feature = "gpu")]
(Self::Cuda(ct), Device::Cpu) => {
let new_inner =
with_thread_local_cuda_stream(|stream| ct.to_radix_ciphertext(stream));
*self = Self::Cpu(BooleanBlock::new_unchecked(new_inner.blocks[0].clone()));
}
}
}

#[inline]
#[allow(clippy::unused_self)]
pub(crate) fn move_to_device_of_server_key_if_set(&mut self) {
#[cfg(feature = "gpu")]
if let Some(device) = global_state::device_of_internal_keys() {
self.move_to_device(device);
}
}
}
use super::inner::InnerBoolean;

/// The FHE boolean data type.
///
Expand Down Expand Up @@ -209,6 +48,18 @@ impl Named for FheBool {
const NAME: &'static str = "high_level_api::FheBool";
}

impl From<CompressedFheBool> for FheBool {
fn from(value: CompressedFheBool) -> Self {
value.decompress()
}
}

impl From<CompactFheBool> for FheBool {
fn from(value: CompactFheBool) -> Self {
value.expand()
}
}

impl FheBool {
pub(in crate::high_level_api) fn new<T: Into<InnerBoolean>>(ciphertext: T) -> Self {
Self {
Expand Down Expand Up @@ -516,121 +367,6 @@ impl FheEq<bool> for FheBool {
}
}

impl From<CompressedFheBool> for FheBool {
fn from(value: CompressedFheBool) -> Self {
value.decompress()
}
}

impl From<CompactFheBool> for FheBool {
fn from(value: CompactFheBool) -> Self {
value.expand()
}
}

impl FheTryEncrypt<bool, ClientKey> for FheBool {
type Error = crate::high_level_api::errors::Error;

fn try_encrypt(value: bool, key: &ClientKey) -> Result<Self, Self::Error> {
let integer_client_key = &key.key.key;
let mut ciphertext = Self::new(integer_client_key.encrypt_bool(value));
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
Ok(ciphertext)
}
}

impl FheTryEncrypt<bool, CompactPublicKey> for FheBool {
type Error = crate::high_level_api::errors::Error;

fn try_encrypt(value: bool, key: &CompactPublicKey) -> Result<Self, Self::Error> {
let mut ciphertext = key.key.key.encrypt_radix(value as u8, 1);
ciphertext.blocks[0].degree = Degree::new(1);
Ok(Self::new(BooleanBlock::new_unchecked(
ciphertext.blocks.into_iter().next().unwrap(),
)))
}
}

impl FheTrivialEncrypt<bool> for FheBool {
/// Creates a trivial encryption of a bool.
///
/// # Warning
///
/// Trivial encryptions are not real encryptions, as a trivially encrypted
/// ciphertext can be decrypted by any key (in fact, no key is actually needed).
///
/// Trivial encryptions become real encrypted data once used in an operation
/// that involves a real ciphertext
///
/// # Example
///
/// ```rust
/// # use tfhe::{ConfigBuilder, generate_keys, set_server_key, FheBool};
/// # use tfhe::prelude::*;
/// #
/// # let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let a = FheBool::encrypt_trivial(true);
///
/// let decrypted: bool = a.decrypt(&client_key);
/// assert_eq!(decrypted, true);
/// ```
#[track_caller]
fn encrypt_trivial(value: bool) -> Self {
Self::try_encrypt_trivial(value).unwrap()
}
}

impl FheTryEncrypt<bool, CompressedPublicKey> for FheBool {
type Error = crate::high_level_api::errors::Error;

fn try_encrypt(value: bool, key: &CompressedPublicKey) -> Result<Self, Self::Error> {
let key = &key.key;
let mut ciphertext = Self::new(key.encrypt_bool(value));
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
Ok(ciphertext)
}
}

impl FheTryEncrypt<bool, PublicKey> for FheBool {
type Error = crate::high_level_api::errors::Error;

fn try_encrypt(value: bool, key: &PublicKey) -> Result<Self, Self::Error> {
let key = &key.key;
let mut ciphertext = Self::new(key.encrypt_bool(value));
ciphertext.ciphertext.move_to_device_of_server_key_if_set();
Ok(ciphertext)
}
}

impl FheDecrypt<bool> for FheBool {
/// Decrypts the value
fn decrypt(&self, key: &ClientKey) -> bool {
key.key.key.decrypt_bool(&self.ciphertext.on_cpu())
}
}

impl FheTryTrivialEncrypt<bool> for FheBool {
type Error = crate::high_level_api::errors::Error;

fn try_encrypt_trivial(value: bool) -> Result<Self, Self::Error> {
let ciphertext = global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(key) => {
InnerBoolean::Cpu(key.pbs_key().create_trivial_boolean_block(value))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_stream(|stream| {
let inner = cuda_key
.key
.create_trivial_radix(u64::from(value), 1, stream);
InnerBoolean::Cuda(inner)
}),
});
Ok(Self::new(ciphertext))
}
}

impl<B> BitAnd<B> for FheBool
where
B: Borrow<Self>,
Expand Down
Loading

0 comments on commit 48f67fb

Please sign in to comment.