arithmetic: Avoid heap & simplify alignment logic in `elem_exp_consttime`.

Avoid allocating on the heap. Let the compiler do the alignment
instead of manually aligning the start of the table.
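As a sketch of the technique (illustrative code with made-up sizes; the crate's real type is `AlignedStorage` in the new `limbs512` module below):

```rust
fn alignment_demo() {
    const ALIGN: usize = 64;
    const LIMB_BYTES: usize = core::mem::size_of::<u64>();

    // Old pattern: over-allocate a heap Vec, then index forward to the next
    // 64-byte boundary by hand.
    let storage = vec![0u64; 100 + ALIGN / LIMB_BYTES];
    let misalignment = (storage.as_ptr() as usize) % ALIGN;
    let aligned = &storage[(ALIGN - misalignment) / LIMB_BYTES..][..100];
    assert_eq!(aligned.as_ptr() as usize % ALIGN, 0);

    // New pattern: a #[repr(align)] wrapper gives 64-byte-aligned storage
    // wherever the value lives, with no allocation and no pointer arithmetic.
    #[repr(C, align(64))]
    struct Aligned64<const N: usize>([u64; N]);
    let table = Aligned64::<100>([0; 100]);
    assert_eq!(table.0.as_ptr() as usize % ALIGN, 0);
}
```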
briansmith committed Jan 30, 2025
1 parent c0c9ad9 commit 8041b7e
Showing 5 changed files with 199 additions and 66 deletions.
1 change: 1 addition & 0 deletions src/arithmetic.rs
@@ -26,6 +26,7 @@ mod constant;
pub mod bigint;

pub(crate) mod inout;
mod limbs512;
pub mod montgomery;

mod n0;
175 changes: 109 additions & 66 deletions src/arithmetic/bigint.rs
@@ -42,14 +42,14 @@ pub(crate) use self::{
modulusvalue::OwnedModulusValue,
private_exponent::PrivateExponent,
};
use super::{inout::AliasingSlices3, montgomery::*, LimbSliceError, MAX_LIMBS};
use super::{inout::AliasingSlices3, limbs512, montgomery::*, LimbSliceError, MAX_LIMBS};
use crate::{
bits::BitLength,
c,
error::{self, LenMismatchError},
limb::{self, Limb, LIMB_BITS},
polyfill::slice::{self, AsChunks},
};
use alloc::vec;
use core::{
marker::PhantomData,
num::{NonZeroU64, NonZeroUsize},
@@ -410,20 +410,57 @@ pub(crate) fn elem_exp_vartime<M>(
acc
}

#[cfg(not(target_arch = "x86_64"))]
pub fn elem_exp_consttime<M>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, LimbSliceError> {
use crate::{bssl, limb::Window};
// `elem_exp_consttime_inner` is parameterized on `STORAGE_LIMBS` only so
// we can run tests with larger-than-supported-in-operation test vectors.
elem_exp_consttime_inner::<M, { ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS * STORAGE_ENTRIES }>(
base, exponent, m,
)
}

const WINDOW_BITS: usize = 5;
const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
// The maximum modulus size supported for `elem_exp_consttime` in normal
// operation.
const ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS: usize = 2048 / LIMB_BITS;
const _LIMBS_PER_CHUNK_DIVIDES_ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS: () =
assert!(ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS % limbs512::LIMBS_PER_CHUNK == 0);
const WINDOW_BITS: u32 = 5;
const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
const STORAGE_ENTRIES: usize = TABLE_ENTRIES + if cfg!(target_arch = "x86_64") { 3 } else { 0 };
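These constants encode a fixed 5-bit-window exponentiation: the exponent is consumed five bits at a time, each window selecting one of 2^5 = 32 precomputed powers (and on 64-bit targets the modulus cap works out to 2048/64 = 32 limbs, a multiple of the 8-limb chunk size). A plain-integer sketch of the windowing structure — no Montgomery form and no constant-time guarantees, unlike the real code:

```rust
/// Fixed-window modular exponentiation: base^exp (mod m), for m > 1.
fn mod_exp_fixed_window(base: u64, exp: u64, m: u64) -> u64 {
    const WINDOW_BITS: u32 = 5;
    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS; // 32

    let mul = |a: u64, b: u64| ((a as u128 * b as u128) % m as u128) as u64;

    // Precompute table[i] = base^i mod m for i in 0..TABLE_ENTRIES.
    let mut table = [1 % m; TABLE_ENTRIES];
    for i in 1..TABLE_ENTRIES {
        table[i] = mul(table[i - 1], base);
    }

    // Consume the exponent most-significant window first: square once per
    // bit of the window, then multiply in the power the window selects.
    let windows = (64 + WINDOW_BITS - 1) / WINDOW_BITS; // 13 windows of 5 bits
    let mut acc = 1 % m;
    for i in (0..windows).rev() {
        for _ in 0..WINDOW_BITS {
            acc = mul(acc, acc);
        }
        let w = ((exp >> (i * WINDOW_BITS)) & (TABLE_ENTRIES as u64 - 1)) as usize;
        acc = mul(acc, table[w]);
    }
    acc // e.g. mod_exp_fixed_window(2, 10, 1000) == 24
}
```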

#[cfg(not(target_arch = "x86_64"))]
fn elem_exp_consttime_inner<M, const STORAGE_LIMBS: usize>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, LimbSliceError> {
use crate::{bssl, limb::Window};

let num_limbs = m.limbs().len();
let m_chunked: AsChunks<Limb, { limbs512::LIMBS_PER_CHUNK }> = match slice::as_chunks(m.limbs())
{
(m, []) => m,
_ => {
return Err(LimbSliceError::len_mismatch(LenMismatchError::new(
num_limbs,
)))

}
};
let cpe = m_chunked.len(); // 512-bit chunks per entry.
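Here `slice::as_chunks` is the crate's polyfill of the still-unstable standard-library API, and the `(m, [])` arm insists the modulus is a whole number of 8-limb (512-bit) chunks. A standalone sketch of the same exact-fit check on stable Rust, with 64-bit limbs assumed and a plain `usize` standing in for `LenMismatchError`:

```rust
/// Split limbs into 8-limb chunks, or report the offending length.
fn as_512bit_chunks(limbs: &[u64]) -> Result<Vec<&[u64; 8]>, usize> {
    let chunks = limbs.chunks_exact(8);
    if !chunks.remainder().is_empty() {
        return Err(limbs.len());
    }
    // try_from cannot fail: chunks_exact yields exactly-8-element slices.
    // (Collecting to a Vec is for simplicity; the crate's AsChunks borrows.)
    Ok(chunks.map(|c| <&[u64; 8]>::try_from(c).unwrap()).collect())
}
```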

// This code doesn't have the strict alignment requirements that the x86_64
// version does, but uses the same aligned storage for convenience.
assert!(STORAGE_LIMBS % (STORAGE_ENTRIES * limbs512::LIMBS_PER_CHUNK) == 0); // TODO: `const`
let mut table = limbs512::AlignedStorage::<STORAGE_LIMBS>::zeroed();
let mut table = table
.aligned_chunks_mut(TABLE_ENTRIES, cpe)
.map_err(LimbSliceError::len_mismatch)?;

let mut table = vec![0; TABLE_ENTRIES * num_limbs];
// TODO: Rewrite the below in terms of `AsChunks`.
let table = table.as_flattened_mut();

fn gather<M>(table: &[Limb], acc: &mut Elem<M, R>, i: Window) {
prefixed_extern! {
@@ -463,9 +500,9 @@ pub fn elem_exp_consttime<M>(
}

// table[0] = base**0 (i.e. 1).
m.oneR(entry_mut(&mut table, 0, num_limbs));
m.oneR(entry_mut(table, 0, num_limbs));

entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
entry_mut(table, 1, num_limbs).copy_from_slice(&base.limbs);
for i in 2..TABLE_ENTRIES {
let (src1, src2) = if i % 2 == 0 {
(i / 2, i / 2)
@@ -497,7 +534,7 @@ pub fn elem_exp_consttime<M>(
}
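In the loop above each entry costs one Montgomery multiplication: an even entry squares entry `i/2`, and the odd-index branch (elided from the diff; `(i - 1, 1)` is my reading of it) multiplies the previous entry by entry 1. A plain-integer sketch of the fill pattern:

```rust
/// Build table[i] = base^i mod m for i in 0..32, one multiplication per entry.
fn fill_power_table(base: u64, m: u64) -> [u64; 32] {
    let mul = |a: u64, b: u64| ((a as u128 * b as u128) % m as u128) as u64;
    let mut table = [0u64; 32];
    table[0] = 1 % m; // the real code stores 1 in Montgomery form (oneR)
    table[1] = base % m;
    for i in 2..32 {
        let (src1, src2) = if i % 2 == 0 { (i / 2, i / 2) } else { (i - 1, 1) };
        table[i] = mul(table[src1], table[src2]);
    }
    table
}
```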

#[cfg(target_arch = "x86_64")]
pub fn elem_exp_consttime<M>(
fn elem_exp_consttime_inner<M, const STORAGE_LIMBS: usize>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &Modulus<M>,
@@ -508,71 +545,60 @@ pub fn elem_exp_consttime<M>(
intel::{Adx, Bmi2},
GetFeature as _,
},
limb::LIMB_BYTES,
polyfill::slice::{self, AsChunks, AsChunksMut},
limb::{LeakyWindow, Window},
polyfill::slice::AsChunksMut,
};

let cpu2 = m.cpu_features().get_feature();
let cpu3 = m.cpu_features().get_feature();

// The x86_64 assembly was written under the assumption that the input data
// is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL.
// Subsequently, it was changed such that, according to BoringSSL, they
// only require 16 byte alignment. We enforce the old, stronger, alignment
// unless/until we can see a benefit to reducing it.
//
// Similarly, OpenSSL uses the x86_64 assembly functions by giving it only
// inputs `tmp`, `am`, and `np` that immediately follow the table. All the
// awkwardness here stems from trying to use the assembly code like OpenSSL
// does.

use crate::limb::{LeakyWindow, Window};

const WINDOW_BITS: usize = 5;
const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;

let num_limbs = m.limbs().len();

const ALIGNMENT: usize = 64;
assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
let (table, state) = {
let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
table.split_at_mut(TABLE_ENTRIES * num_limbs)
// inputs `tmp`, `am`, and `np` that immediately follow the table.
// According to BoringSSL, in older versions of the OpenSSL code, this
// extra space was required for memory safety because the assembly code
// would over-read the table; according to BoringSSL, this is no longer the
// case. Regardless, the upstream code also contained comments implying
// that this was also important for performance. For now, we do as OpenSSL
// did/does.
const MOD_EXP_CTIME_ALIGN: usize = 64;
// Required by the x86_64 assembly.
const _TABLE_ENTRIES_IS_32: () = assert!(TABLE_ENTRIES == 32);
const _STORAGE_ENTRIES_HAS_3_EXTRA: () = assert!(STORAGE_ENTRIES == TABLE_ENTRIES + 3);

let m_original: AsChunks<Limb, 8> = match slice::as_chunks(m.limbs()) {
(m, []) => m,
_ => return Err(LimbSliceError::len_mismatch(LenMismatchError::new(8))),

};
let cpe = m_original.len(); // 512-bit chunks per entry.

// These are named `(tmp, am, np)` in BoringSSL.
let (acc, base_cached, m_cached): (&mut [Limb], &[Limb], &[Limb]) = {
let (acc, rest) = state.split_at_mut(num_limbs);
let (base_cached, rest) = rest.split_at_mut(num_limbs);

// Upstream, the input `base` is not Montgomery-encoded, so they compute a
// Montgomery-encoded copy and store it here.
base_cached.copy_from_slice(&base.limbs);
assert!(STORAGE_LIMBS % (STORAGE_ENTRIES * limbs512::LIMBS_PER_CHUNK) == 0); // TODO: `const`
let mut table = limbs512::AlignedStorage::<STORAGE_LIMBS>::zeroed();
let mut table = table
.aligned_chunks_mut(STORAGE_ENTRIES, cpe)
.map_err(LimbSliceError::len_mismatch)?;
let (mut table, mut state) = table.split_at_mut(TABLE_ENTRIES * cpe);
assert_eq!((table.as_ptr() as usize) % MOD_EXP_CTIME_ALIGN, 0);

let m_cached = &mut rest[..num_limbs];
// "To improve cache locality" according to upstream.
m_cached.copy_from_slice(m.limbs());
// These are named `(tmp, am, np)` in BoringSSL.
let (mut acc, mut rest) = state.split_at_mut(cpe);
let (mut base_cached, mut m_cached) = rest.split_at_mut(cpe);

(acc, base_cached, m_cached)
};
// Upstream, the input `base` is not Montgomery-encoded, so they compute a
// Montgomery-encoded copy and store it here.
base_cached.as_flattened_mut().copy_from_slice(&base.limbs);
let base_cached = base_cached.as_ref();

let n0 = m.n0();

// TODO: Move the use of `Chunks`/`ChunksMut` up into the signature of the
// function so this conversion isn't necessary.
let (mut table, mut acc, base_cached, m_cached) = match (
slice::as_chunks_mut(table),
slice::as_chunks_mut(acc),
slice::as_chunks(base_cached),
slice::as_chunks(m_cached),
) {
((table, []), (acc, []), (base_cached, []), (m_cached, [])) => {
(table, acc, base_cached, m_cached)
}
_ => {
// XXX: Not the best error to return
return Err(LimbSliceError::len_mismatch(LenMismatchError::new(8)));
}
};
// "To improve cache locality" according to upstream.
m_cached
.as_flattened_mut()
.copy_from_slice(m_original.as_flattened());
let m_cached = m_cached.as_ref();

// Fill in all the powers of 2 of `acc` into the table using only squaring and without any
// gathering, storing the last calculated power into `acc`.
@@ -605,6 +631,8 @@ pub fn elem_exp_consttime<M>(
acc.as_flattened_mut()
.copy_from_slice(base_cached.as_flattened());

let n0 = m.n0();

// Fill in entries 1, 2, 4, 8, 16.
scatter_powers_of_2(table.as_mut(), acc.as_mut(), m_cached, n0, 1, cpu2)?;
// Fill in entries 3, 6, 12, 24; 5, 10, 20; 7, 14, 28; 9, 18; 11, 22; 13, 26; 15, 30;
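The schedule in these comments is: entry 1 seeds the powers of two, then each remaining odd entry seeds a chain of repeated doublings, so every index in 1..32 is written exactly once. A sketch that reproduces it, one inner chain per `scatter_powers_of_2` call:

```rust
/// Fill order: [1, 2, 4, 8, 16], [3, 6, 12, 24], [5, 10, 20], [7, 14, 28],
/// [9, 18], [11, 22], [13, 26], [15, 30], [17], [19], ..., [31].
fn scatter_schedule() -> Vec<Vec<usize>> {
    (1..32usize)
        .step_by(2)
        .map(|odd| {
            let mut chain = Vec::new();
            let mut i = odd;
            while i < 32 {
                chain.push(i);
                i *= 2; // squaring acc doubles the exponent it holds
            }
            chain
        })
        .collect()
}
```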
@@ -715,10 +743,25 @@ mod tests {
.expect("valid exponent")
};
let base = into_encoded(base, &m);
let actual_result = elem_exp_consttime(base, &e, &m)
.map_err(error::erase::<LimbSliceError>)
.unwrap();
assert_elem_eq(&actual_result, &expected_result);

let too_big = m.limbs().len() > ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS;
let actual_result = if !too_big {
elem_exp_consttime(base, &e, &m)
} else {
let actual_result = elem_exp_consttime(base.clone(), &e, &m);
// TODO: Be more specific with which error we expect?
assert!(actual_result.is_err());
// Try again with a larger-than-normally-supported limit
elem_exp_consttime_inner::<_, { (4096 / LIMB_BITS) * STORAGE_ENTRIES }>(
base, &e, &m,
)
};
match actual_result {

Ok(r) => assert_elem_eq(&r, &expected_result),
Err(LimbSliceError::LenMismatch { .. }) => panic!(),
Err(LimbSliceError::TooLong { .. }) => panic!(),
Err(LimbSliceError::TooShort { .. }) => panic!(),

};

Ok(())
},
17 changes: 17 additions & 0 deletions src/arithmetic/limbs512/mod.rs
@@ -0,0 +1,17 @@
// Copyright 2025 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

mod storage;

pub(super) use self::storage::{AlignedStorage, LIMBS_PER_CHUNK};
60 changes: 60 additions & 0 deletions src/arithmetic/limbs512/storage.rs
@@ -0,0 +1,60 @@
// Copyright 2025 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

use crate::{
error::LenMismatchError,
limb::{Limb, LIMB_BITS},
polyfill::slice::{self, AsChunksMut},
};
use core::mem::{align_of, size_of};

// Some x86_64 assembly is written under the assumption that some of its
// input data and/or temporary storage is aligned to `MOD_EXP_CTIME_ALIGN`
// bytes, which was/is 64 in OpenSSL.
//
// We use this in the non-x86_64 implementation of exponentiation as well,
// with the hope of converging the two implementations into one.

#[repr(C, align(64))]
pub struct AlignedStorage<const N: usize>([Limb; N]);

const _LIMB_SIZE_DIVIDES_ALIGNMENT: () =
assert!(align_of::<AlignedStorage<1>>() % size_of::<Limb>() == 0);

pub const LIMBS_PER_CHUNK: usize = 512 / LIMB_BITS;

impl<const N: usize> AlignedStorage<N> {
pub fn zeroed() -> Self {
assert_eq!(N % LIMBS_PER_CHUNK, 0); // TODO: const.
Self([0; N])
}

// The result will have every chunk aligned on a 64 byte boundary.
pub fn aligned_chunks_mut(
&mut self,
num_entries: usize,
chunks_per_entry: usize,
) -> Result<AsChunksMut<Limb, LIMBS_PER_CHUNK>, LenMismatchError> {
let total_limbs = num_entries * chunks_per_entry * LIMBS_PER_CHUNK;
let len = self.0.len();
let flattened = self
.0
.get_mut(..total_limbs)
.ok_or_else(|| LenMismatchError::new(len))?;
match slice::as_chunks_mut(flattened) {
(chunks, []) => Ok(chunks),
(_, r) => Err(LenMismatchError::new(r.len())),

}
}
}
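A sketch of how the exponentiation paths above use this type, with made-up concrete numbers for a 2048-bit modulus on a 64-bit target (`LIMB_BITS == 64`, so `LIMBS_PER_CHUNK == 8`, and `STORAGE_ENTRIES == 35` on x86_64):

```rust
// 35 entries x 4 chunks x 8 limbs = 1120 limbs, zeroed, with the whole region
// (and therefore every 64-byte chunk) aligned to 64 bytes — and no heap.
fn storage_demo() -> Result<(), LenMismatchError> {
    const STORAGE_LIMBS: usize = 35 * (2048 / LIMB_BITS);
    let mut storage = AlignedStorage::<STORAGE_LIMBS>::zeroed();
    let chunks_per_entry = (2048 / LIMB_BITS) / LIMBS_PER_CHUNK; // 4
    let mut chunks = storage.aligned_chunks_mut(35, chunks_per_entry)?;
    assert_eq!(chunks.as_flattened_mut().len(), STORAGE_LIMBS);
    Ok(())
}
```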
12 changes: 12 additions & 0 deletions src/polyfill/slice/as_chunks_mut.rs
@@ -40,6 +40,11 @@ impl<T, const N: usize> AsChunksMut<'_, T, N> {
self.0.as_ptr().cast()
}

#[cfg(target_arch = "x86_64")]
pub fn as_ptr(&self) -> *const [T; N] {
self.0.as_ptr().cast()
}

#[cfg(target_arch = "aarch64")]
pub fn as_mut_ptr(&mut self) -> *mut [T; N] {
self.0.as_mut_ptr().cast()
@@ -62,6 +67,13 @@ impl<T, const N: usize> AsChunksMut<'_, T, N> {
pub fn chunks_mut<const CHUNK_LEN: usize>(&mut self) -> AsChunksMutChunksMutIter<T, N> {
AsChunksMutChunksMutIter(self.0.chunks_mut(CHUNK_LEN * N))
}

#[cfg(target_arch = "x86_64")]
#[inline(always)]
pub fn split_at_mut(&mut self, mid: usize) -> (AsChunksMut<T, N>, AsChunksMut<T, N>) {
let (before, after) = self.0.split_at_mut(mid * N);
(AsChunksMut(before), AsChunksMut(after))
}
}
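Note that `mid` counts chunks, not limbs: the underlying slice splits at `mid * N`. A usage sketch (hypothetical helper; assumes the view holds at least `32 * cpe` chunks and 64-bit limbs):

```rust
/// Split a 32-entry table off the front, leaving the scratch entries, as the
/// x86_64 `elem_exp_consttime_inner` does; `cpe` is chunks per table entry.
fn split_table_from_state(mut chunks: AsChunksMut<'_, u64, 8>, cpe: usize) {
    let (mut table, state) = chunks.split_at_mut(32 * cpe);
    assert_eq!(table.as_flattened_mut().len(), 32 * cpe * 8); // limbs, not chunks
    let _ = state; // would become the tmp/am/np scratch area
}
```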

pub struct AsChunksMutChunksMutIter<'a, T, const N: usize>(core::slice::ChunksMut<'a, T>);