From 8041b7e4f7dc616a5bf1c1a38079de40a64221de Mon Sep 17 00:00:00 2001
From: Brian Smith <brian@briansmith.org>
Date: Mon, 27 Jan 2025 13:53:32 -0800
Subject: [PATCH] arithmetic: Avoid heap & simplify alignment logic in
 `elem_exp_consttime`.

Avoid allocating on the heap. Let the compiler do the alignment
instead of manually aligning the start of the table.
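
As a minimal sketch of the idea (illustrative names only, not the exact
types added by this patch): instead of over-allocating a buffer and
offsetting to the next 64-byte boundary by hand, declare a wrapper type
whose alignment the compiler enforces:

    // Hypothetical illustration; `Aligned` is not a type in this patch.
    #[repr(C, align(64))]
    struct Aligned<const N: usize>([u64; N]);

    fn main() {
        let storage = Aligned([0u64; 64]);
        // Alignment is a property of the type; no pointer arithmetic.
        assert_eq!(storage.0.as_ptr() as usize % 64, 0);
    }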
---
 src/arithmetic.rs                   |   1 +
 src/arithmetic/bigint.rs            | 175 +++++++++++++++++-----------
 src/arithmetic/limbs512/mod.rs      |  17 +++
 src/arithmetic/limbs512/storage.rs  |  60 ++++++++++
 src/polyfill/slice/as_chunks_mut.rs |  12 ++
 5 files changed, 199 insertions(+), 66 deletions(-)
 create mode 100644 src/arithmetic/limbs512/mod.rs
 create mode 100644 src/arithmetic/limbs512/storage.rs

diff --git a/src/arithmetic.rs b/src/arithmetic.rs
index f810741f09..3242dc41a9 100644
--- a/src/arithmetic.rs
+++ b/src/arithmetic.rs
@@ -26,6 +26,7 @@ mod constant;
 pub mod bigint;
 
 pub(crate) mod inout;
+mod limbs512;
 pub mod montgomery;
 
 mod n0;
diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs
index f948a23cbd..13335d758e 100644
--- a/src/arithmetic/bigint.rs
+++ b/src/arithmetic/bigint.rs
@@ -42,14 +42,14 @@ pub(crate) use self::{
     modulusvalue::OwnedModulusValue,
     private_exponent::PrivateExponent,
 };
-use super::{inout::AliasingSlices3, montgomery::*, LimbSliceError, MAX_LIMBS};
+use super::{inout::AliasingSlices3, limbs512, montgomery::*, LimbSliceError, MAX_LIMBS};
 use crate::{
     bits::BitLength,
     c,
     error::{self, LenMismatchError},
     limb::{self, Limb, LIMB_BITS},
+    polyfill::slice::{self, AsChunks},
 };
-use alloc::vec;
 use core::{
     marker::PhantomData,
     num::{NonZeroU64, NonZeroUsize},
@@ -410,20 +410,57 @@ pub(crate) fn elem_exp_vartime<M>(
     acc
 }
 
-#[cfg(not(target_arch = "x86_64"))]
 pub fn elem_exp_consttime<M>(
     base: Elem<M, R>,
     exponent: &PrivateExponent,
     m: &Modulus<M>,
 ) -> Result<Elem<M, Unencoded>, LimbSliceError> {
-    use crate::{bssl, limb::Window};
+    // `elem_exp_consttime_inner` is parameterized on `STORAGE_LIMBS` only so
+    // that tests can run with vectors larger than the modulus size supported
+    // in normal operation.
+    elem_exp_consttime_inner::<M, { ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS * STORAGE_ENTRIES }>(
+        base, exponent, m,
+    )
+}
 
-    const WINDOW_BITS: usize = 5;
-    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
+// The maximum modulus size supported for `elem_exp_consttime` in normal
+// operation.
+const ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS: usize = 2048 / LIMB_BITS;
+const _LIMBS_PER_CHUNK_DIVIDES_ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS: () =
+    assert!(ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS % limbs512::LIMBS_PER_CHUNK == 0);
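+// Fixed-window exponentiation: the table holds base**0 through
+// base**(TABLE_ENTRIES - 1) (in Montgomery form), and each WINDOW_BITS-bit
+// window of the exponent selects one entry.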
+const WINDOW_BITS: u32 = 5;
+const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
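+// On x86_64, three extra entries immediately after the table hold the values
+// BoringSSL names `tmp`, `am`, and `np` (see below).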
+const STORAGE_ENTRIES: usize = TABLE_ENTRIES + if cfg!(target_arch = "x86_64") { 3 } else { 0 };
+
+#[cfg(not(target_arch = "x86_64"))]
+fn elem_exp_consttime_inner<M, const STORAGE_LIMBS: usize>(
+    base: Elem<M, R>,
+    exponent: &PrivateExponent,
+    m: &Modulus<M>,
+) -> Result<Elem<M, Unencoded>, LimbSliceError> {
+    use crate::{bssl, limb::Window};
 
     let num_limbs = m.limbs().len();
+    let m_chunked: AsChunks<Limb, { limbs512::LIMBS_PER_CHUNK }> = match slice::as_chunks(m.limbs())
+    {
+        (m, []) => m,
+        _ => {
+            return Err(LimbSliceError::len_mismatch(LenMismatchError::new(
+                num_limbs,
+            )))
+        }
+    };
+    let cpe = m_chunked.len(); // 512-bit chunks per entry.
+
+    // This code doesn't have the strict alignment requirements that the x86_64
+    // version does, but uses the same aligned storage for convenience.
+    assert!(STORAGE_LIMBS % (STORAGE_ENTRIES * limbs512::LIMBS_PER_CHUNK) == 0); // TODO: `const`
+    let mut table = limbs512::AlignedStorage::<STORAGE_LIMBS>::zeroed();
+    let mut table = table
+        .aligned_chunks_mut(TABLE_ENTRIES, cpe)
+        .map_err(LimbSliceError::len_mismatch)?;
 
-    let mut table = vec![0; TABLE_ENTRIES * num_limbs];
+    // TODO: Rewrite the below in terms of `AsChunks`.
+    let table = table.as_flattened_mut();
 
     fn gather<M>(table: &[Limb], acc: &mut Elem<M, R>, i: Window) {
         prefixed_extern! {
@@ -463,9 +500,9 @@ pub fn elem_exp_consttime<M>(
     }
 
     // table[0] = base**0 (i.e. 1).
-    m.oneR(entry_mut(&mut table, 0, num_limbs));
+    m.oneR(entry_mut(table, 0, num_limbs));
 
-    entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
+    entry_mut(table, 1, num_limbs).copy_from_slice(&base.limbs);
     for i in 2..TABLE_ENTRIES {
         let (src1, src2) = if i % 2 == 0 {
             (i / 2, i / 2)
@@ -497,7 +534,7 @@ pub fn elem_exp_consttime<M>(
 }
 
 #[cfg(target_arch = "x86_64")]
-pub fn elem_exp_consttime<M>(
+fn elem_exp_consttime_inner<M, const STORAGE_LIMBS: usize>(
     base: Elem<M, R>,
     exponent: &PrivateExponent,
     m: &Modulus<M>,
@@ -508,8 +545,8 @@ pub fn elem_exp_consttime<M>(
             intel::{Adx, Bmi2},
             GetFeature as _,
         },
-        limb::LIMB_BYTES,
-        polyfill::slice::{self, AsChunks, AsChunksMut},
+        limb::{LeakyWindow, Window},
+        polyfill::slice::AsChunksMut,
     };
 
     let cpu2 = m.cpu_features().get_feature();
@@ -517,62 +554,51 @@ pub fn elem_exp_consttime<M>(
 
     // The x86_64 assembly was written under the assumption that the input data
     // is aligned to `MOD_EXP_CTIME_ALIGN` bytes, which was/is 64 in OpenSSL.
+    // Subsequently, the assembly was changed so that, according to BoringSSL,
+    // it requires only 16-byte alignment. We enforce the old, stronger
+    // alignment unless/until we see a benefit to reducing it.
+    //
     // Similarly, OpenSSL uses the x86_64 assembly functions by giving it only
-    // inputs `tmp`, `am`, and `np` that immediately follow the table. All the
-    // awkwardness here stems from trying to use the assembly code like OpenSSL
-    // does.
-
-    use crate::limb::{LeakyWindow, Window};
-
-    const WINDOW_BITS: usize = 5;
-    const TABLE_ENTRIES: usize = 1 << WINDOW_BITS;
-
-    let num_limbs = m.limbs().len();
-
-    const ALIGNMENT: usize = 64;
-    assert_eq!(ALIGNMENT % LIMB_BYTES, 0);
-    let mut table = vec![0; ((TABLE_ENTRIES + 3) * num_limbs) + ALIGNMENT];
-    let (table, state) = {
-        let misalignment = (table.as_ptr() as usize) % ALIGNMENT;
-        let table = &mut table[((ALIGNMENT - misalignment) / LIMB_BYTES)..];
-        assert_eq!((table.as_ptr() as usize) % ALIGNMENT, 0);
-        table.split_at_mut(TABLE_ENTRIES * num_limbs)
+    // inputs `tmp`, `am`, and `np` that immediately follow the table.
+    // According to BoringSSL, in older versions of the OpenSSL code, this
+    // extra space was required for memory safety because the assembly would
+    // over-read the table; BoringSSL reports that this is no longer the case.
+    // Regardless, the upstream code also contained comments implying that the
+    // layout was important for performance. For now, we do as OpenSSL does.
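+    //
+    // The resulting layout, where each entry is the size of one element:
+    //
+    //     [ 32 table entries | tmp | am | np ]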
+    const MOD_EXP_CTIME_ALIGN: usize = 64;
+    // Required by the x86_64 assembly's assumptions about the table layout:
+    const _TABLE_ENTRIES_IS_32: () = assert!(TABLE_ENTRIES == 32);
+    const _STORAGE_ENTRIES_HAS_3_EXTRA: () = assert!(STORAGE_ENTRIES == TABLE_ENTRIES + 3);
+
+    let m_original: AsChunks<Limb, 8> = match slice::as_chunks(m.limbs()) {
+        (m, []) => m,
+        _ => return Err(LimbSliceError::len_mismatch(LenMismatchError::new(8))),
     };
+    let cpe = m_original.len(); // 512-bit chunks per entry.
 
-    // These are named `(tmp, am, np)` in BoringSSL.
-    let (acc, base_cached, m_cached): (&mut [Limb], &[Limb], &[Limb]) = {
-        let (acc, rest) = state.split_at_mut(num_limbs);
-        let (base_cached, rest) = rest.split_at_mut(num_limbs);
-
-        // Upstream, the input `base` is not Montgomery-encoded, so they compute a
-        // Montgomery-encoded copy and store it here.
-        base_cached.copy_from_slice(&base.limbs);
+    assert!(STORAGE_LIMBS % (STORAGE_ENTRIES * limbs512::LIMBS_PER_CHUNK) == 0); // TODO: `const`
+    let mut table = limbs512::AlignedStorage::<STORAGE_LIMBS>::zeroed();
+    let mut table = table
+        .aligned_chunks_mut(STORAGE_ENTRIES, cpe)
+        .map_err(LimbSliceError::len_mismatch)?;
+    let (mut table, mut state) = table.split_at_mut(TABLE_ENTRIES * cpe);
+    assert_eq!((table.as_ptr() as usize) % MOD_EXP_CTIME_ALIGN, 0);
 
-        let m_cached = &mut rest[..num_limbs];
-        // "To improve cache locality" according to upstream.
-        m_cached.copy_from_slice(m.limbs());
+    // These are named `(tmp, am, np)` in BoringSSL.
+    let (mut acc, mut rest) = state.split_at_mut(cpe);
+    let (mut base_cached, mut m_cached) = rest.split_at_mut(cpe);
 
-        (acc, base_cached, m_cached)
-    };
+    // Upstream, the input `base` is not Montgomery-encoded, so they compute a
+    // Montgomery-encoded copy and store it here.
+    base_cached.as_flattened_mut().copy_from_slice(&base.limbs);
+    let base_cached = base_cached.as_ref();
 
-    let n0 = m.n0();
-
-    // TODO: Move the use of `Chunks`/`ChunksMut` up into the signature of the
-    // function so this conversion isn't necessary.
-    let (mut table, mut acc, base_cached, m_cached) = match (
-        slice::as_chunks_mut(table),
-        slice::as_chunks_mut(acc),
-        slice::as_chunks(base_cached),
-        slice::as_chunks(m_cached),
-    ) {
-        ((table, []), (acc, []), (base_cached, []), (m_cached, [])) => {
-            (table, acc, base_cached, m_cached)
-        }
-        _ => {
-            // XXX: Not the best error to return
-            return Err(LimbSliceError::len_mismatch(LenMismatchError::new(8)));
-        }
-    };
+    // "To improve cache locality" according to upstream.
+    m_cached
+        .as_flattened_mut()
+        .copy_from_slice(m_original.as_flattened());
+    let m_cached = m_cached.as_ref();
 
     // Fill in all the powers of 2 of `acc` into the table using only squaring and without any
     // gathering, storing the last calculated power into `acc`.
@@ -605,6 +631,8 @@ pub fn elem_exp_consttime<M>(
     acc.as_flattened_mut()
         .copy_from_slice(base_cached.as_flattened());
 
+    let n0 = m.n0();
+
     // Fill in entries 1, 2, 4, 8, 16.
     scatter_powers_of_2(table.as_mut(), acc.as_mut(), m_cached, n0, 1, cpu2)?;
     // Fill in entries 3, 6, 12, 24; 5, 10, 20, 30; 7, 14, 28; 9, 18; 11, 22; 13, 26; 15, 30;
@@ -715,10 +743,25 @@ mod tests {
                         .expect("valid exponent")
                 };
                 let base = into_encoded(base, &m);
-                let actual_result = elem_exp_consttime(base, &e, &m)
-                    .map_err(error::erase::<LimbSliceError>)
-                    .unwrap();
-                assert_elem_eq(&actual_result, &expected_result);
+
+                let too_big = m.limbs().len() > ELEM_EXP_CONSTTIME_MAX_MODULUS_LIMBS;
+                let actual_result = if !too_big {
+                    elem_exp_consttime(base, &e, &m)
+                } else {
+                    let actual_result = elem_exp_consttime(base.clone(), &e, &m);
+                    // TODO: Be more specific about which error we expect.
+                    assert!(actual_result.is_err());
+                    // Try again with a larger-than-normally-supported limit
+                    elem_exp_consttime_inner::<_, { (4096 / LIMB_BITS) * STORAGE_ENTRIES }>(
+                        base, &e, &m,
+                    )
+                };
+                match actual_result {
+                    Ok(r) => assert_elem_eq(&r, &expected_result),
+                    Err(LimbSliceError::LenMismatch { .. }) => panic!(),
+                    Err(LimbSliceError::TooLong { .. }) => panic!(),
+                    Err(LimbSliceError::TooShort { .. }) => panic!(),
+                };
 
                 Ok(())
             },
diff --git a/src/arithmetic/limbs512/mod.rs b/src/arithmetic/limbs512/mod.rs
new file mode 100644
index 0000000000..122cb16651
--- /dev/null
+++ b/src/arithmetic/limbs512/mod.rs
@@ -0,0 +1,17 @@
+// Copyright 2025 Brian Smith.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+mod storage;
+
+pub(super) use self::storage::{AlignedStorage, LIMBS_PER_CHUNK};
diff --git a/src/arithmetic/limbs512/storage.rs b/src/arithmetic/limbs512/storage.rs
new file mode 100644
index 0000000000..91cf44139b
--- /dev/null
+++ b/src/arithmetic/limbs512/storage.rs
@@ -0,0 +1,60 @@
+// Copyright 2025 Brian Smith.
+//
+// Permission to use, copy, modify, and/or distribute this software for any
+// purpose with or without fee is hereby granted, provided that the above
+// copyright notice and this permission notice appear in all copies.
+//
+// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
+// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+use crate::{
+    error::LenMismatchError,
+    limb::{Limb, LIMB_BITS},
+    polyfill::slice::{self, AsChunksMut},
+};
+use core::mem::{align_of, size_of};
+
+// Some x86_64 assembly is written under the assumption that some of its
+// input data and/or temporary storage is aligned to `MOD_EXP_CTIME_ALIGN`
+// bytes, which was/is 64 in OpenSSL.
+//
+// We use this in the non-x86_64 implementation of exponentiation as well,
+// with the hope of converging the two implementations into one.
+
+#[repr(C, align(64))]
+pub struct AlignedStorage<const N: usize>([Limb; N]);
+
+const _LIMB_SIZE_DIVIDES_ALIGNMENT: () =
+    assert!(align_of::<AlignedStorage<1>>() % size_of::<Limb>() == 0);
+
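+// A 512-bit chunk is 64 bytes regardless of limb size, so within the
+// 64-byte-aligned storage every chunk starts on a 64-byte boundary.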
+pub const LIMBS_PER_CHUNK: usize = 512 / LIMB_BITS;
+
+impl<const N: usize> AlignedStorage<N> {
+    pub fn zeroed() -> Self {
+        assert_eq!(N % LIMBS_PER_CHUNK, 0); // TODO: const.
+        Self([0; N])
+    }
+
+    // The result will have every chunk aligned on a 64-byte boundary.
+    pub fn aligned_chunks_mut(
+        &mut self,
+        num_entries: usize,
+        chunks_per_entry: usize,
+    ) -> Result<AsChunksMut<Limb, LIMBS_PER_CHUNK>, LenMismatchError> {
+        let total_limbs = num_entries * chunks_per_entry * LIMBS_PER_CHUNK;
+        let len = self.0.len();
+        let flattened = self
+            .0
+            .get_mut(..total_limbs)
+            .ok_or_else(|| LenMismatchError::new(len))?;
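+        // `total_limbs` is a multiple of `LIMBS_PER_CHUNK` by construction,
+        // so the remainder below is empty and the `Err` arm is unreachable.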
+        match slice::as_chunks_mut(flattened) {
+            (chunks, []) => Ok(chunks),
+            (_, r) => Err(LenMismatchError::new(r.len())),
+        }
+    }
+}
diff --git a/src/polyfill/slice/as_chunks_mut.rs b/src/polyfill/slice/as_chunks_mut.rs
index f2bb7a9de2..a4364eb868 100644
--- a/src/polyfill/slice/as_chunks_mut.rs
+++ b/src/polyfill/slice/as_chunks_mut.rs
@@ -40,6 +40,11 @@ impl<T, const N: usize> AsChunksMut<'_, T, N> {
         self.0.as_ptr().cast()
     }
 
+    #[cfg(target_arch = "x86_64")]
+    pub fn as_ptr(&self) -> *const [T; N] {
+        self.0.as_ptr().cast()
+    }
+
     #[cfg(target_arch = "aarch64")]
     pub fn as_mut_ptr(&mut self) -> *mut [T; N] {
         self.0.as_mut_ptr().cast()
@@ -62,6 +67,13 @@ impl<T, const N: usize> AsChunksMut<'_, T, N> {
     pub fn chunks_mut<const CHUNK_LEN: usize>(&mut self) -> AsChunksMutChunksMutIter<T, N> {
         AsChunksMutChunksMutIter(self.0.chunks_mut(CHUNK_LEN * N))
     }
+
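+    // `mid` is a chunk index, not an element index: the underlying slice is
+    // split at `mid * N` elements, mirroring `<[T]>::split_at_mut`.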
+    #[cfg(target_arch = "x86_64")]
+    #[inline(always)]
+    pub fn split_at_mut(&mut self, mid: usize) -> (AsChunksMut<T, N>, AsChunksMut<T, N>) {
+        let (before, after) = self.0.split_at_mut(mid * N);
+        (AsChunksMut(before), AsChunksMut(after))
+    }
 }
 
 pub struct AsChunksMutChunksMutIter<'a, T, const N: usize>(core::slice::ChunksMut<'a, T>);