Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address clippy warnings #6

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,7 @@ impl Client {
}

let masked = rounded_res & mat_elem_mask;
let unmasked = masked.wrapping_add(binary_fuse_filter::mix(hash, idx as u64) as u32) & mat_elem_mask;

unmasked
masked.wrapping_add(binary_fuse_filter::mix(hash, idx as u64) as u32) & mat_elem_mask
})
.collect::<Vec<u32>>();

Expand Down
54 changes: 26 additions & 28 deletions src/pir_internals/binary_fuse_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub struct BinaryFuseFilter {
pub mat_elem_bit_len: usize,
}

type BinaryFuseFilterIntermediateStageResult<'a> = (BinaryFuseFilter, Vec<u64>, Vec<u8>, HashMap<u64, &'a [u8]>);

impl BinaryFuseFilter {
/// Constructs a 3-wise xor Binary Fuse Filter. This implementation collects inspiration from https://github.com/FastFilter/fastfilter_cpp/blob/5df1dc5063702945f6958e4bda445dd082aed366/src/xorfilter/3wise_xor_binary_fuse_filter_lowmem.h.
///
Expand All @@ -32,7 +34,7 @@ impl BinaryFuseFilter {
db: &HashMap<&'a [u8], &[u8]>,
mat_elem_bit_len: usize,
max_attempt_count: usize,
) -> Option<(BinaryFuseFilter, Vec<u64>, Vec<u8>, HashMap<u64, &'a [u8]>)> {
) -> Option<BinaryFuseFilterIntermediateStageResult<'a>> {
const ARITY: u32 = 3;

let db_size = db.len();
Expand All @@ -45,11 +47,11 @@ impl BinaryFuseFilter {
let size_factor = size_factor::<ARITY>(db_size as u32);
let capacity = if db_size > 1 { ((db_size as f64) * size_factor).round() as u32 } else { 0 };

let init_segment_count = (capacity + segment_length - 1) / segment_length;
let init_segment_count = capacity.div_ceil(segment_length);
let (num_fingerprints, segment_count) = {
let array_len = init_segment_count * segment_length;
let segment_count: u32 = {
let proposed = (array_len + segment_length - 1) / segment_length;
let proposed = array_len.div_ceil(segment_length);
if proposed < ARITY {
1
} else {
Expand Down Expand Up @@ -93,8 +95,8 @@ impl BinaryFuseFilter {
for _ in 0..max_attempt_count {
rng.fill_bytes(&mut seed);

for i in 0..start_pos_len {
start_pos[i] = (((i as u64) * (db_size as u64)) >> block_bits) as usize;
for (idx, val) in start_pos.iter_mut().enumerate() {
*val = (((idx as u64) * (db_size as u64)) >> block_bits) as usize;
}

for &key in db.keys() {
Expand All @@ -114,9 +116,7 @@ impl BinaryFuseFilter {
}

let mut error = false;
for i in 0..db_size {
let hash = reverse_order[i];

for &hash in reverse_order.iter().take(db_size) {
let (h0, h1, h2) = hash_batch_for_3_wise_xor_filter(hash, segment_length, segment_count_length);
let (h0, h1, h2) = (h0 as usize, h1 as usize, h2 as usize);

Expand All @@ -143,9 +143,9 @@ impl BinaryFuseFilter {
}

let mut qsize = 0;
for i in 0..num_fingerprints {
alone[qsize] = i as u32;
if (t2count[i] >> 2) == 1 {
for (idx, &count) in t2count.iter().enumerate().take(num_fingerprints) {
alone[qsize] = idx as u32;
if (count >> 2) == 1 {
qsize += 1;
}
}
Expand Down Expand Up @@ -240,7 +240,7 @@ impl BinaryFuseFilter {
db: &HashMap<&'a [u8], &[u8]>,
mat_elem_bit_len: usize,
max_attempt_count: usize,
) -> Option<(BinaryFuseFilter, Vec<u64>, Vec<u8>, HashMap<u64, &'a [u8]>)> {
) -> Option<BinaryFuseFilterIntermediateStageResult<'a>> {
const ARITY: u32 = 4;

let db_size = db.len();
Expand All @@ -253,11 +253,11 @@ impl BinaryFuseFilter {
let size_factor = size_factor::<ARITY>(db_size as u32);
let capacity = if db_size > 1 { ((db_size as f64) * size_factor).round() as u32 } else { 0 };

let init_segment_count = (capacity + segment_length - 1) / segment_length;
let init_segment_count = capacity.div_ceil(segment_length);
let (num_fingerprints, segment_count) = {
let array_len = init_segment_count * segment_length;
let segment_count: u32 = {
let proposed = (array_len + segment_length - 1) / segment_length;
let proposed = array_len.div_ceil(segment_length);
if proposed < ARITY {
1
} else {
Expand Down Expand Up @@ -301,8 +301,8 @@ impl BinaryFuseFilter {
for _ in 0..max_attempt_count {
rng.fill_bytes(&mut seed);

for i in 0..start_pos_len {
start_pos[i] = (((i as u64) * (db_size as u64)) >> block_bits) as usize;
for (idx, val) in start_pos.iter_mut().enumerate().take(start_pos_len) {
*val = (((idx as u64) * (db_size as u64)) >> block_bits) as usize;
}

for &key in db.keys() {
Expand All @@ -322,9 +322,7 @@ impl BinaryFuseFilter {
}

let mut count_mask = 0u8;
for i in 0..db_size {
let hash = reverse_order[i];

for &hash in reverse_order.iter().take(db_size) {
let (h0, h1, h2, h3) = hash_batch_for_4_wise_xor_filter(hash, segment_length, segment_count_length);
let (h0, h1, h2, h3) = (h0 as usize, h1 as usize, h2 as usize, h3 as usize);

Expand All @@ -333,17 +331,17 @@ impl BinaryFuseFilter {
count_mask |= t2count[h0];

t2count[h1] += 4;
t2count[h1] ^= 1 as u8;
t2count[h1] ^= 1u8;
t2hash[h1] ^= hash;
count_mask |= t2count[h1];

t2count[h2] += 4;
t2count[h2] ^= 2 as u8;
t2count[h2] ^= 2u8;
t2hash[h2] ^= hash;
count_mask |= t2count[h2];

t2count[h3] += 4;
t2count[h3] ^= 3 as u8;
t2count[h3] ^= 3u8;
t2hash[h3] ^= hash;
count_mask |= t2count[h3];
}
Expand All @@ -357,9 +355,9 @@ impl BinaryFuseFilter {
}

let mut qsize = 0;
for i in 0..num_fingerprints {
alone[qsize] = i as u32;
if (t2count[i] >> 2) == 1 {
for (idx, &count) in t2count.iter().enumerate().take(num_fingerprints) {
alone[qsize] = idx as u32;
if (count >> 2) == 1 {
qsize += 1;
}
}
Expand Down Expand Up @@ -532,15 +530,15 @@ pub fn hash_of_key(key: &[u8]) -> [u64; 4] {
}

#[inline(always)]
pub fn mix256<'a>(key: &[u64; 4], seed: &[u8; 32]) -> u64 {
pub fn mix256(key: &[u64; 4], seed: &[u8; 32]) -> u64 {
let seed_words = [
u64::from_le_bytes(seed[..8].try_into().unwrap()),
u64::from_le_bytes(seed[8..16].try_into().unwrap()),
u64::from_le_bytes(seed[16..24].try_into().unwrap()),
u64::from_le_bytes(seed[24..].try_into().unwrap()),
];

key.into_iter()
key.iter()
.map(|&k| {
seed_words
.into_iter()
Expand Down Expand Up @@ -574,7 +572,7 @@ pub const fn hash_batch_for_4_wise_xor_filter(hash: u64, segment_length: u32, se
let mut h2 = h1 + segment_length;
let mut h3 = h2 + segment_length;

h1 ^= ((hash >> 0) as u32) & segment_length_mask;
h1 ^= (hash as u32) & segment_length_mask;
h2 ^= ((hash >> 16) as u32) & segment_length_mask;
h3 ^= ((hash >> 32) as u32) & segment_length_mask;

Expand Down
11 changes: 5 additions & 6 deletions src/pir_internals/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ impl Matrix {
let mut res = Matrix::new(self.cols, self.rows).unwrap();

(0..self.cols)
.map(|ridx| (0..self.rows).map(move |cidx| (ridx, cidx)))
.flatten()
.flat_map(|ridx| (0..self.rows).map(move |cidx| (ridx, cidx)))
.for_each(|(ridx, cidx)| {
res[(ridx, cidx)] = self[(cidx, ridx)];
});
Expand Down Expand Up @@ -359,7 +358,7 @@ impl Matrix {

let row = serialization::encode_kv_as_row(key, value, mat_elem_bit_len, cols);

let mat_row_idx0 = h012[found + 0] as usize;
let mat_row_idx0 = h012[found] as usize;
let mat_row_idx1 = h012[found + 1] as usize;
let mat_row_idx2 = h012[found + 2] as usize;

Expand Down Expand Up @@ -494,7 +493,7 @@ impl Matrix {

let row = serialization::encode_kv_as_row(key, value, mat_elem_bit_len, cols);

let mat_row_idx0 = h0123[found + 0] as usize;
let mat_row_idx0 = h0123[found] as usize;
let mat_row_idx1 = h0123[found + 1] as usize;
let mat_row_idx2 = h0123[found + 2] as usize;
let mat_row_idx3 = h0123[found + 3] as usize;
Expand Down Expand Up @@ -632,7 +631,7 @@ impl Mul for Matrix {
}
}

impl<'a, 'b> Mul<&'b Matrix> for &'a Matrix {
impl<'b> Mul<&'b Matrix> for &Matrix {
type Output = Option<Matrix>;

fn mul(self, rhs: &'b Matrix) -> Self::Output {
Expand Down Expand Up @@ -662,7 +661,7 @@ impl Add for Matrix {
}
}

impl<'a, 'b> Add<&'b Matrix> for &'a Matrix {
impl<'b> Add<&'b Matrix> for &Matrix {
type Output = Option<Matrix>;

fn add(self, rhs: &'b Matrix) -> Self::Output {
Expand Down
8 changes: 4 additions & 4 deletions src/pir_internals/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ pub fn u64_from_le_bytes(bytes: &[u8]) -> u64 {
let mut word = 0;
let readable_num_bytes = min(bytes.len(), std::mem::size_of::<u64>());

for i in 0..readable_num_bytes {
word |= (bytes[i] as u64) << (i * 8);
for (idx, &byte) in bytes.iter().enumerate().take(readable_num_bytes) {
word |= (byte as u64) << (idx * 8);
}

word
Expand All @@ -220,8 +220,8 @@ pub fn u64_from_le_bytes(bytes: &[u8]) -> u64 {
pub fn u64_to_le_bytes(word: u64, bytes: &mut [u8]) {
let writable_num_bytes = min(bytes.len(), std::mem::size_of::<u64>());

for i in 0..writable_num_bytes {
bytes[i] = (word >> i * 8) as u8;
for (idx, byte) in bytes.iter_mut().enumerate().take(writable_num_bytes) {
*byte = (word >> (idx * 8)) as u8;
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::pir_internals::{
matrix::Matrix,
params::{LWE_DIMENSION, SERVER_SETUP_MAX_ATTEMPT_COUNT},
};
use std::{collections::HashMap, u32};
use std::collections::HashMap;

#[derive(Clone)]
pub struct Server {
Expand Down