From 666b60d35873288b69ffaf18dc7e44574b69d1f3 Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Fri, 17 Jan 2025 20:38:14 -0800 Subject: [PATCH 1/2] arithmetic: Rename `limbs_mont_square` to `limbs_square_mont`. Be more consistent with the non-squaring Montgomery multiplication functions. --- src/arithmetic/bigint.rs | 4 ++-- src/arithmetic/montgomery.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 33269a0b46..f3d867550b 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -205,7 +205,7 @@ fn elem_squared( where (E, E): ProductEncoding, { - limbs_mont_square(&mut a.limbs, m.limbs(), m.n0(), m.cpu_features()); + limbs_square_mont(&mut a.limbs, m.limbs(), m.n0(), m.cpu_features()); Elem { limbs: a.limbs, encoding: PhantomData, @@ -638,7 +638,7 @@ pub fn elem_exp_consttime( if i >= TABLE_ENTRIES as LeakyWindow { break; } - limbs_mont_square(acc, m_cached, n0, cpu_features); + limbs_square_mont(acc, m_cached, n0, cpu_features); } } diff --git a/src/arithmetic/montgomery.rs b/src/arithmetic/montgomery.rs index 1a695f19d0..5f5a14a847 100644 --- a/src/arithmetic/montgomery.rs +++ b/src/arithmetic/montgomery.rs @@ -300,7 +300,7 @@ pub(super) fn limbs_mont_product( } /// r = r**2 -pub(super) fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, cpu_features: cpu::Features) { +pub(super) fn limbs_square_mont(r: &mut [Limb], m: &[Limb], n0: &N0, cpu_features: cpu::Features) { debug_assert_eq!(r.len(), m.len()); unsafe { mul_mont( From 8f4d5f0945ed21f3f1f688c0aa6a10b2c11a91da Mon Sep 17 00:00:00 2001 From: Brian Smith Date: Fri, 17 Jan 2025 19:55:57 -0800 Subject: [PATCH 2/2] arithmetic: Pass inputs to `limbs_mul_mont` as slices. Remove some gratuitous unsafety. In theory there are many patterns of aliasing the arguments when calling `bn_mul_mont`, but in practice we only have three: 1. r *= a (mod n) 2. r = a * b (mod n) 3. r = r * r (mod n) Rename `mul_mont` to `limbs_mul_mont` and have it handle both #1 & #2. Refactor it so that its arguments are slices. Remove the `limbs_mont_mul` and `limbs_mont_product` wrappers around `limbs_mul_mont` in favor of exposing `limbs_mul_mont` directly. Change `limbs_square_mont` to call `bn_mul_mont` directly. Although we could have `mul_mont` handle this by making the new `InOut` type more complicated, but we'd just end up undoing this when `bn_mul_mont` is split into separate squaring and non-squaring functions later. --- src/arithmetic.rs | 3 +- src/arithmetic/bigint.rs | 26 ++++++++-- src/arithmetic/inout.rs | 19 +++++++ src/arithmetic/montgomery.rs | 99 ++++++++++-------------------------- 4 files changed, 69 insertions(+), 78 deletions(-) create mode 100644 src/arithmetic/inout.rs diff --git a/src/arithmetic.rs b/src/arithmetic.rs index e3dc6c4489..b240d9df2a 100644 --- a/src/arithmetic.rs +++ b/src/arithmetic.rs @@ -17,6 +17,7 @@ mod constant; #[cfg(feature = "alloc")] pub mod bigint; +mod inout; pub mod montgomery; mod n0; @@ -24,4 +25,4 @@ mod n0; #[allow(dead_code)] const BIGINT_MODULUS_MAX_LIMBS: usize = 8192 / crate::limb::LIMB_BITS; -pub use constant::limbs_from_hex; +pub use self::{constant::limbs_from_hex, inout::InOut}; diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index f3d867550b..e0be21c8b4 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -42,8 +42,8 @@ pub(crate) use self::{ modulusvalue::OwnedModulusValue, private_exponent::PrivateExponent, }; +use super::{montgomery::*, InOut}; use crate::{ - arithmetic::montgomery::*, bits::BitLength, c, error, limb::{self, Limb, LIMB_BITS}, @@ -99,7 +99,13 @@ fn from_montgomery_amm(limbs: BoxedLimbs, m: &Modulus) -> Elem( where (AF, BF): ProductEncoding, { - limbs_mont_mul(&mut b.limbs, &a.limbs, m.limbs(), m.n0(), m.cpu_features()); + limbs_mul_mont( + InOut::InPlace(&mut b.limbs), + &a.limbs, + m.limbs(), + m.n0(), + m.cpu_features(), + ); Elem { limbs: b.limbs, encoding: PhantomData, @@ -467,7 +479,13 @@ pub fn elem_exp_consttime( let src1 = entry(previous, src1, num_limbs); let src2 = entry(previous, src2, num_limbs); let dst = entry_mut(rest, 0, num_limbs); - limbs_mont_product(dst, src1, src2, m.limbs(), m.n0(), m.cpu_features()); + limbs_mul_mont( + InOut::Disjoint(dst, src1), + src2, + m.limbs(), + m.n0(), + m.cpu_features(), + ); } let tmp = m.zero(); diff --git a/src/arithmetic/inout.rs b/src/arithmetic/inout.rs new file mode 100644 index 0000000000..daafb4f5a9 --- /dev/null +++ b/src/arithmetic/inout.rs @@ -0,0 +1,19 @@ +// Copyright 2025 Brian Smith. +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +pub enum InOut<'io, T: ?Sized> { + InPlace(&'io mut T), + #[cfg_attr(target_arch = "x86_64", allow(dead_code))] + Disjoint(&'io mut T, &'io T), +} diff --git a/src/arithmetic/montgomery.rs b/src/arithmetic/montgomery.rs index 5f5a14a847..b58b4ef090 100644 --- a/src/arithmetic/montgomery.rs +++ b/src/arithmetic/montgomery.rs @@ -12,7 +12,7 @@ // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -pub use super::n0::N0; +pub use super::{n0::N0, InOut}; use crate::cpu; // Indicates that the element is not encoded; there is no *R* factor @@ -113,15 +113,24 @@ impl ProductEncoding for (RRR, RInverse) { use crate::{bssl, c, limb::Limb}; #[inline(always)] -unsafe fn mul_mont( - r: *mut Limb, - a: *const Limb, - b: *const Limb, - n: *const Limb, - n0: &N0, - num_limbs: c::size_t, - _: cpu::Features, -) { +pub(super) fn limbs_mul_mont(ra: InOut<[Limb]>, b: &[Limb], n: &[Limb], n0: &N0, _: cpu::Features) { + // XXX/TODO: All the `debug_assert_eq!` length checking needs to be + // replaced with enforcement that happens regardless of debug mode. + let (r, a) = match ra { + InOut::InPlace(r) => { + debug_assert_eq!(r.len(), n.len()); + (r.as_mut_ptr(), r.as_ptr()) + } + InOut::Disjoint(r, a) => { + debug_assert_eq!(r.len(), n.len()); + debug_assert_eq!(a.len(), n.len()); + (r.as_mut_ptr(), a.as_ptr()) + } + }; + debug_assert_eq!(b.len(), n.len()); + let b = b.as_ptr(); + let num_limbs = n.len(); + let n = n.as_ptr(); unsafe { bn_mul_mont(r, a, b, n, n0, num_limbs) } } @@ -249,71 +258,15 @@ prefixed_extern! { ); } -/// r *= a -pub(super) fn limbs_mont_mul( - r: &mut [Limb], - a: &[Limb], - m: &[Limb], - n0: &N0, - cpu_features: cpu::Features, -) { - debug_assert_eq!(r.len(), m.len()); - debug_assert_eq!(a.len(), m.len()); - unsafe { - mul_mont( - r.as_mut_ptr(), - r.as_ptr(), - a.as_ptr(), - m.as_ptr(), - n0, - r.len(), - cpu_features, - ) - } -} - -/// r = a * b -#[cfg(not(target_arch = "x86_64"))] -pub(super) fn limbs_mont_product( - r: &mut [Limb], - a: &[Limb], - b: &[Limb], - m: &[Limb], - n0: &N0, - cpu_features: cpu::Features, -) { - debug_assert_eq!(r.len(), m.len()); - debug_assert_eq!(a.len(), m.len()); - debug_assert_eq!(b.len(), m.len()); - - unsafe { - mul_mont( - r.as_mut_ptr(), - a.as_ptr(), - b.as_ptr(), - m.as_ptr(), - n0, - r.len(), - cpu_features, - ) - } -} - /// r = r**2 -pub(super) fn limbs_square_mont(r: &mut [Limb], m: &[Limb], n0: &N0, cpu_features: cpu::Features) { - debug_assert_eq!(r.len(), m.len()); - unsafe { - mul_mont( - r.as_mut_ptr(), - r.as_ptr(), - r.as_ptr(), - m.as_ptr(), - n0, - r.len(), - cpu_features, - ) - } +pub(super) fn limbs_square_mont(r: &mut [Limb], n: &[Limb], n0: &N0, _cpu: cpu::Features) { + debug_assert_eq!(r.len(), n.len()); + let r = r.as_mut_ptr(); + let num_limbs = n.len(); + let n = n.as_ptr(); + unsafe { bn_mul_mont(r, r, r, n, n0, num_limbs) } } + #[cfg(test)] mod tests { use super::*;