From 8041b7e4f7dc616a5bf1c1a38079de40a64221de Mon Sep 17 00:00:00 2001
From: Brian Smith <brian@briansmith.org>
Date: Mon, 27 Jan 2025 13:53:32 -0800
Subject: [PATCH] arithmetic: Avoid heap & simplify alignment logic in
 `elem_exp_consttime`.

Avoid allocating on the heap. Let the compiler do the alignment instead of
manually aligning the start of the table.
---
 src/arithmetic.rs                   |   1 +
 src/arithmetic/bigint.rs            | 175 +++++++++++++++++-----------
 src/arithmetic/limbs512/mod.rs      |  17 +++
 src/arithmetic/limbs512/storage.rs  |  60 ++++++++++
 src/polyfill/slice/as_chunks_mut.rs |  12 ++
 5 files changed, 199 insertions(+), 66 deletions(-)
 create mode 100644 src/arithmetic/limbs512/mod.rs
 create mode 100644 src/arithmetic/limbs512/storage.rs

diff --git a/src/arithmetic.rs b/src/arithmetic.rs
index f810741f09..3242dc41a9 100644
--- a/src/arithmetic.rs
+++ b/src/arithmetic.rs
@@ -26,6 +26,7 @@ mod constant;
 
 pub mod bigint;
 pub(crate) mod inout;
+mod limbs512;
 pub mod montgomery;
 
 mod n0;
diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs
index f948a23cbd..13335d758e 100644
--- a/src/arithmetic/bigint.rs
+++ b/src/arithmetic/bigint.rs
@@ -42,14 +42,14 @@ pub(crate) use self::{
     modulusvalue::OwnedModulusValue, private_exponent::PrivateExponent,
 };
-use super::{inout::AliasingSlices3, montgomery::*, LimbSliceError, MAX_LIMBS};
+use super::{inout::AliasingSlices3, limbs512, montgomery::*, LimbSliceError, MAX_LIMBS};
 use crate::{
     bits::BitLength,
     c,
     error::{self, LenMismatchError},
     limb::{self, Limb, LIMB_BITS},
+    polyfill::slice::{self, AsChunks},
 };
-use alloc::vec;
 use core::{
     marker::PhantomData,
     num::{NonZeroU64, NonZeroUsize},
@@ -410,20 +410,57 @@ pub(crate) fn elem_exp_vartime<M>(
     acc
 }
 
-#[cfg(not(target_arch = "x86_64"))]
 pub fn elem_exp_consttime<M>(
     base: Elem<M, R>,
     exponent: &PrivateExponent,
     m: &Modulus<M>,
 ) -> Result<Elem<M, Unencoded>, LimbSliceError> {
-    use crate::{bssl, limb::Window};
+    // `elem_exp_consttime_inner` is parameterized on `STORAGE_LIMBS` only so
+    // we can run tests with larger-than-supported-in-operation test vectors.
+    elem_exp_consttime_inner::<M, { ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS * STORAGE_ENTRIES }>(
+        base, exponent, m,
+    )
+}
 
-    const WINDOW_BITS: usize = 5;
-    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
+// The maximum modulus size supported for `elem_exp_consttime` in normal
+// operation.
+const ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS: usize = 2048 / LIMB_BITS;
+const _LIMBS_PER_CHUNK_DIVIDES_ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS: () =
+    assert!(ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS % limbs512::LIMBS_PER_CHUNK == 0);
+const WINDOW_BITS: u32 = 5;
+const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
+const STORAGE_ENTRIES: usize = TABLE_ENTRIES + if cfg!(target_arch = "x86_64") { 3 } else { 0 };
+
+#[cfg(not(target_arch = "x86_64"))]
+fn elem_exp_consttime_inner<M, const STORAGE_LIMBS: usize>(
+    base: Elem<M, R>,
+    exponent: &PrivateExponent,
+    m: &Modulus<M>,
+) -> Result<Elem<M, Unencoded>, LimbSliceError> {
+    use crate::{bssl, limb::Window};
 
     let num_limbs = m.limbs().len();
+    let m_chunked: AsChunks<Limb, { limbs512::LIMBS_PER_CHUNK }> = match slice::as_chunks(m.limbs())
+    {
+        (m, []) => m,
+        _ => {
+            return Err(LimbSliceError::len_mismatch(LenMismatchError::new(
+                num_limbs,
+            )))
+        }
+    };
+    let cpe = m_chunked.len(); // 512-bit chunks per entry.
+
+    // This code doesn't have the strict alignment requirements that the x86_64
+    // version does, but uses the same aligned storage for convenience.
+    assert!(STORAGE_LIMBS % (STORAGE_ENTRIES * limbs512::LIMBS_PER_CHUNK) == 0); // TODO: `const`
+    let mut table = limbs512::AlignedStorage::<STORAGE_LIMBS>::zeroed();
+    let mut table = table
+        .aligned_chunks_mut(TABLE_ENTRIES, cpe)
+        .map_err(LimbSliceError::len_mismatch)?;
-    let mut table = vec![0; TABLE_ENTRIES * num_limbs];
+    // TODO: Rewrite the below in terms of `AsChunks`.
+    let table = table.as_flattened_mut();
 
     fn gather<M>(table: &[Limb], acc: &mut Elem<M, R>, i: Window) {
         prefixed_extern! {
@@ -463,9 +500,9 @@ pub fn elem_exp_consttime<M>(
     }
 
     // table[0] = base**0 (i.e. 1).
-    m.oneR(entry_mut(&mut table, 0, num_limbs));
+    m.oneR(entry_mut(table, 0, num_limbs));
 
-    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
+    entry_mut(table, 1, num_limbs).copy_from_slice(&base.limbs);
     for i in 2..TABLE_ENTRIES {
         let (src1, src2) = if i % 2 == 0 {
             (i / 2, i / 2)
@@ -497,7 +534,7 @@
 }
 
 #[cfg(target_arch = "x86_64")]
-pub fn elem_exp_consttime<M>(
+fn elem_exp_consttime_inner<M, const STORAGE_LIMBS: usize>(
     base: Elem<M, R>,
     exponent: &PrivateExponent,
     m: &Modulus<M>,
@@ -508,8 +545,8 @@
             intel::{Adx, Bmi2},
             GetFeature as _,
         },
-        limb::LIMB_BYTES,
-        polyfill::slice::{self, AsChunks, AsChunksMut},
+        limb::{LeakyWindow, Window},
+        polyfill::slice::AsChunksMut,
     };
 
     let cpu2 = m.cpu_features().get_feature();
@@ -517,62 +554,51 @@
 
     // The x86_64 assembly was written under the assumption that the input data
     // is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL.
+    // Subsequently, it was changed such that, according to BoringSSL, they
+    // only require 16 byte alignment. We enforce the old, stronger, alignment
+    // unless/until we can see a benefit to reducing it.
+    //
     // Similarly, OpenSSL uses the x86_64 assembly functions by giving it only
-    // inputs `tmp`, `am`, and `np` that immediately follow the table. All the
-    // awkwardness here stems from trying to use the assembly code like OpenSSL
-    // does.
-
-    use crate::limb::{LeakyWindow, Window};
-
-    const WINDOW_BITS: usize = 5;
-    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
-
-    let num_limbs = m.limbs().len();
-
-    const ALIGNMENT: usize = 64;
-    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
-    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
-    let (table, state) = {
-        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
-        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
-        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
-        table.split_at_mut(TABLE_ENTRIES * num_limbs)
+    // inputs `tmp`, `am`, and `np` that immediately follow the table.
+    // According to BoringSSL, in older versions of the OpenSSL code, this
+    // extra space was required for memory safety because the assembly code
+    // would over-read the table; according to BoringSSL, this is no longer the
+    // case. Regardless, the upstream code also contained comments implying
+    // that this was also important for performance. For now, we do as OpenSSL
+    // did/does.
+    const MOD_EXP_CTIME_ALIGN: usize = 64;
+    // Required by
+    const _TABLE_ENTRIES_IS_32: () = assert!(TABLE_ENTRIES == 32);
+    const _STORAGE_ENTRIES_HAS_3_EXTRA: () = assert!(STORAGE_ENTRIES == TABLE_ENTRIES + 3);
+
+    let m_original: AsChunks<Limb, 8> = match slice::as_chunks(m.limbs()) {
+        (m, []) => m,
+        _ => return Err(LimbSliceError::len_mismatch(LenMismatchError::new(8))),
     };
+    let cpe = m_original.len(); // 512-bit chunks per entry.
 
-    // These are named `(tmp, am, np)` in BoringSSL.
-    let (acc, base_cached, m_cached): (&mut [Limb], &[Limb], &[Limb]) = {
-        let (acc, rest) = state.split_at_mut(num_limbs);
-        let (base_cached, rest) = rest.split_at_mut(num_limbs);
-
-        // Upstream, the input `base` is not Montgomery-encoded, so they compute a
-        // Montgomery-encoded copy and store it here.
-        base_cached.copy_from_slice(&base.limbs);
+    assert!(STORAGE_LIMBS % (STORAGE_ENTRIES * limbs512::LIMBS_PER_CHUNK) == 0); // TODO: `const`
+    let mut table = limbs512::AlignedStorage::<STORAGE_LIMBS>::zeroed();
+    let mut table = table
+        .aligned_chunks_mut(STORAGE_ENTRIES, cpe)
+        .map_err(LimbSliceError::len_mismatch)?;
+    let (mut table, mut state) = table.split_at_mut(TABLE_ENTRIES * cpe);
+    assert_eq!((table.as_ptr() as usize) % MOD_EXP_CTIME_ALIGN, 0);
 
-        let m_cached = &mut rest[..num_limbs];
-        // "To improve cache locality" according to upstream.
-        m_cached.copy_from_slice(m.limbs());
+    // These are named `(tmp, am, np)` in BoringSSL.
+    let (mut acc, mut rest) = state.split_at_mut(cpe);
+    let (mut base_cached, mut m_cached) = rest.split_at_mut(cpe);
 
-        (acc, base_cached, m_cached)
-    };
+    // Upstream, the input `base` is not Montgomery-encoded, so they compute a
+    // Montgomery-encoded copy and store it here.
+    base_cached.as_flattened_mut().copy_from_slice(&base.limbs);
+    let base_cached = base_cached.as_ref();
 
-    let n0 = m.n0();
-
-    // TODO: Move the use of `Chunks`/`ChunksMut` up into the signature of the
-    // function so this conversion isn't necessary.
-    let (mut table, mut acc, base_cached, m_cached) = match (
-        slice::as_chunks_mut(table),
-        slice::as_chunks_mut(acc),
-        slice::as_chunks(base_cached),
-        slice::as_chunks(m_cached),
-    ) {
-        ((table, []), (acc, []), (base_cached, []), (m_cached, [])) => {
-            (table, acc, base_cached, m_cached)
-        }
-        _ => {
-            // XXX: Not the best error to return
-            return Err(LimbSliceError::len_mismatch(LenMismatchError::new(8)));
-        }
-    };
+    // "To improve cache locality" according to upstream.
+    m_cached
+        .as_flattened_mut()
+        .copy_from_slice(m_original.as_flattened());
+    let m_cached = m_cached.as_ref();
 
     // Fill in all the powers of 2 of `acc` into the table using only squaring and without any
     // gathering, storing the last calculated power into `acc`.
@@ -605,6 +631,8 @@
     acc.as_flattened_mut()
         .copy_from_slice(base_cached.as_flattened());
 
+    let n0 = m.n0();
+
     // Fill in entries 1, 2, 4, 8, 16.
     scatter_powers_of_2(table.as_mut(), acc.as_mut(), m_cached, n0, 1, cpu2)?;
     // Fill in entries 3, 6, 12, 24; 5, 10, 20, 30; 7, 14, 28; 9, 18; 11, 22; 13, 26; 15, 30;
@@ -715,10 +743,25 @@ mod tests {
                     .expect("valid exponent")
             };
             let base = into_encoded(base, &m);
-            let actual_result = elem_exp_consttime(base, &e, &m)
-                .map_err(error::erase::<LimbSliceError>)
-                .unwrap();
-            assert_elem_eq(&actual_result, &expected_result);
+
+            let too_big = m.limbs().len() > ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS;
+            let actual_result = if !too_big {
+                elem_exp_consttime(base, &e, &m)
+            } else {
+                let actual_result = elem_exp_consttime(base.clone(), &e, &m);
+                // TODO: Be more specific with which error we expect?
+                assert!(actual_result.is_err());
+                // Try again with a larger-than-normally-supported limit
+                elem_exp_consttime_inner::<_, { (4096 / LIMB_BITS) * STORAGE_ENTRIES }>(
+                    base, &e, &m,
+                )
+            };
+            match actual_result {
+                Ok(r) => assert_elem_eq(&r, &expected_result),
+                Err(LimbSliceError::LenMismatch { .. }) => panic!(),
+                Err(LimbSliceError::TooLong { .. }) => panic!(),
+                Err(LimbSliceError::TooShort { .. }) => panic!(),
+            };
 
             Ok(())
         },
diff --git a/src/arithmetic/limbs512/mod.rs b/src/arithmetic/limbs512/mod.rs
new file mode 100644
index 0000000000..122cb16651
--- /dev/null
+++ b/src/arithmetic/limbs512/mod.rs
@@ -0,0 +1,17 @@
+// Copyright 2025 Brian Smith.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+mod storage;
+
+pub(super) use self::storage::{AlignedStorage, LIMBS_PER_CHUNK};
diff --git a/src/arithmetic/limbs512/storage.rs b/src/arithmetic/limbs512/storage.rs
new file mode 100644
index 0000000000..91cf44139b
--- /dev/null
+++ b/src/arithmetic/limbs512/storage.rs
@@ -0,0 +1,60 @@
+// Copyright 2025 Brian Smith.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+use crate::{
+    error::LenMismatchError,
+    limb::{Limb, LIMB_BITS},
+    polyfill::slice::{self, AsChunksMut},
+};
+use core::mem::{align_of, size_of};
+
+// Some x86_64 assembly is written under the assumption that some of its
+// input data and/or temporary storage is aligned to `MOD_EXP_CTIME_ALIGN`
+// bytes, which was/is 64 in OpenSSL.
+//
+// We use this in the non-X86-64 implementation of exponentiation as well,
+// with the hope of converging the two implementations into one.
+
+#[repr(C, align(64))]
+pub struct AlignedStorage<const N: usize>([Limb; N]);
+
+const _LIMB_SIZE_DIVIDES_ALIGNMENT: () =
+    assert!(align_of::<AlignedStorage<1>>() % size_of::<Limb>() == 0);
+
+pub const LIMBS_PER_CHUNK: usize = 512 / LIMB_BITS;
+
+impl<const N: usize> AlignedStorage<N> {
+    pub fn zeroed() -> Self {
+        assert_eq!(N % LIMBS_PER_CHUNK, 0); // TODO: const.
+        Self([0; N])
+    }
+
+    // The result will have every chunk aligned on a 64 byte boundary.
+    pub fn aligned_chunks_mut(
+        &mut self,
+        num_entries: usize,
+        chunks_per_entry: usize,
+    ) -> Result<AsChunksMut<Limb, LIMBS_PER_CHUNK>, LenMismatchError> {
+        let total_limbs = num_entries * chunks_per_entry * LIMBS_PER_CHUNK;
+        let len = self.0.len();
+        let flattened = self
+            .0
+            .get_mut(..total_limbs)
+            .ok_or_else(|| LenMismatchError::new(len))?;
+        match slice::as_chunks_mut(flattened) {
+            (chunks, []) => Ok(chunks),
+            (_, r) => Err(LenMismatchError::new(r.len())),
+        }
+    }
+}
diff --git a/src/polyfill/slice/as_chunks_mut.rs b/src/polyfill/slice/as_chunks_mut.rs
index f2bb7a9de2..a4364eb868 100644
--- a/src/polyfill/slice/as_chunks_mut.rs
+++ b/src/polyfill/slice/as_chunks_mut.rs
@@ -40,6 +40,11 @@ impl<T, const N: usize> AsChunksMut<'_, T, N> {
         self.0.as_ptr().cast()
     }
 
+    #[cfg(target_arch = "x86_64")]
+    pub fn as_ptr(&self) -> *const [T; N] {
+        self.0.as_ptr().cast()
+    }
+
     #[cfg(target_arch = "aarch64")]
     pub fn as_mut_ptr(&mut self) -> *mut [T; N] {
         self.0.as_mut_ptr().cast()
@@ -62,6 +67,13 @@ impl<T, const N: usize> AsChunksMut<'_, T, N> {
     pub fn chunks_mut<const CHUNK_LEN: usize>(&mut self) -> AsChunksMutChunksMutIter<T, N> {
         AsChunksMutChunksMutIter(self.0.chunks_mut(CHUNK_LEN * N))
    }
+
+    #[cfg(target_arch = "x86_64")]
+    #[inline(always)]
+    pub fn split_at_mut(&mut self, mid: usize) -> (AsChunksMut<T, N>, AsChunksMut<T, N>) {
+        let (before, after) = self.0.split_at_mut(mid * N);
+        (AsChunksMut(before), AsChunksMut(after))
+    }
 }
 
 pub struct AsChunksMutChunksMutIter<'a, T, const N: usize>(core::slice::ChunksMut<'a, T>);
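
Some background notes on the techniques the patch relies on, with illustrative sketches (these are editorial additions, not part of the commit). First, the central change: the deleted code over-allocated a `Vec` and sliced off the misaligned prefix at runtime, while the new `AlignedStorage` lets the compiler guarantee placement via `#[repr(C, align(64))]`. A minimal standalone sketch of that idea (the type name `Aligned64` is illustrative, not part of ring):

    // A buffer whose starting address the compiler must place on a 64-byte
    // boundary, wherever it lives (stack, static, or inside another type).
    #[repr(C, align(64))]
    struct Aligned64<const N: usize>([u64; N]);

    fn main() {
        let buf = Aligned64::<32>([0; 32]);
        // No runtime pointer fix-up needed: alignment is a property of the type.
        assert_eq!(buf.0.as_ptr() as usize % 64, 0);
    }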
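The patch also leans on the `const _NAME: () = assert!(...)` idiom (and leaves two `// TODO: const` markers where a runtime `assert!` remains). It works because a panicking `const` initializer fails the build, so invariants like `TABLE_ENTRIES == 32` are enforced at compile time. A self-contained sketch with stand-in constants (names are illustrative):

    // Evaluating `assert!` in a `const` item turns a violated invariant into
    // a compile error instead of a runtime panic.
    const LIMB_BITS: usize = usize::BITS as usize;
    const LIMBS_PER_CHUNK: usize = 512 / LIMB_BITS;
    const MAX_MODULUS_LIMBS: usize = 2048 / LIMB_BITS;
    const _CHUNK_DIVIDES_MAX: () = assert!(MAX_MODULUS_LIMBS % LIMBS_PER_CHUNK == 0);

    fn main() {
        // Nothing to check at runtime; the assertion above ran at compile time.
    }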
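Both code paths now view the modulus as whole 512-bit chunks (`cpe`, chunks per entry) and reject any input with a remainder. On stable Rust, without ring's `AsChunks` polyfill, the same shape can be expressed with `chunks_exact`; `checked_chunks` below is a hypothetical helper, not ring API:

    // View a limb slice as an exact sequence of 8-limb (512-bit, on 64-bit
    // targets) chunks, failing if the length is not a multiple of 8: the
    // same check the patch performs by matching on an empty remainder.
    fn checked_chunks(limbs: &[u64]) -> Result<Vec<[u64; 8]>, usize> {
        let mut it = limbs.chunks_exact(8);
        let chunks = it.by_ref().map(|c| c.try_into().unwrap()).collect();
        match it.remainder() {
            [] => Ok(chunks),
            r => Err(r.len()), // plays the role of LenMismatchError
        }
    }

    fn main() {
        assert!(checked_chunks(&[0u64; 32]).is_ok()); // four whole chunks
        assert_eq!(checked_chunks(&[0u64; 33]), Err(1)); // one limb left over
    }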
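On the table layout: `WINDOW_BITS = 5` means the exponent is consumed five bits at a time, so the table holds `base^0` through `base^31` and `TABLE_ENTRIES == 1 << WINDOW_BITS`. Here is a deliberately variable-time sketch of the same fixed-window shape over `u64`, for intuition only; the real code squares and gathers in constant time over multi-limb values:

    // Variable-time fixed-window modular exponentiation. Do not use with
    // secret exponents: the direct table lookup leaks the window value.
    fn pow_mod_fixed_window(base: u64, exp: u64, m: u64) -> u64 {
        assert!(m >= 2);
        const WINDOW_BITS: u32 = 5;
        const TABLE_ENTRIES: usize = 1 << WINDOW_BITS; // 32, as in the patch
        let mul = |a: u64, b: u64| ((a as u128 * b as u128) % m as u128) as u64;

        // table[i] = base^i (mod m) for i in 0..32.
        let mut table = [1u64; TABLE_ENTRIES];
        for i in 1..TABLE_ENTRIES {
            table[i] = mul(table[i - 1], base);
        }

        // Walk the exponent in 5-bit windows, most significant first:
        // square five times, then multiply by the entry the window selects.
        let mut bits = (64 - exp.leading_zeros()).div_ceil(WINDOW_BITS) * WINDOW_BITS;
        let mut acc = 1u64;
        while bits > 0 {
            bits -= WINDOW_BITS;
            for _ in 0..WINDOW_BITS {
                acc = mul(acc, acc);
            }
            let window = ((exp >> bits) as usize) & (TABLE_ENTRIES - 1);
            acc = mul(acc, table[window]);
        }
        acc
    }

    fn main() {
        assert_eq!(pow_mod_fixed_window(2, 10, 1000), 24); // 1024 mod 1000
    }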
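Finally, why the scatter/gather dance instead of indexing `table[window]` directly: a direct index would leak exponent bits through the data cache. The fallback's `gather` helper and the x86_64 assembly instead touch every entry and select the wanted one without a secret-dependent access pattern. A portable sketch of constant-time selection, independent of ring's actual routines and table layout:

    // Constant-time selection of table[i]: read all 32 entries and keep only
    // the wanted one via an arithmetic mask, so the memory access pattern
    // does not depend on the secret index.
    fn gather_ct(table: &[[u64; 8]; 32], i: u32) -> [u64; 8] {
        let mut out = [0u64; 8];
        for (j, entry) in table.iter().enumerate() {
            let ne = (j as u32) ^ i; // zero iff j == i
            // All-ones when ne == 0, all-zeros otherwise, with no branch.
            let mask = ((ne as u64).wrapping_sub(1) >> 63).wrapping_neg();
            for (o, &limb) in out.iter_mut().zip(entry.iter()) {
                *o |= limb & mask;
            }
        }
        out
    }

    fn main() {
        let mut table = [[0u64; 8]; 32];
        table[7] = [7; 8];
        assert_eq!(gather_ct(&table, 7), [7; 8]);
        assert_eq!(gather_ct(&table, 8), [0; 8]);
    }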