Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

arithmetic: Pass inputs to limbs_mul_mont as slices. #2247

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ mod constant;
#[cfg(feature = "alloc")]
pub mod bigint;

mod inout;
pub mod montgomery;

mod n0;

#[allow(dead_code)]
const BIGINT_MODULUS_MAX_LIMBS: usize = 8192 / crate::limb::LIMB_BITS;

pub use constant::limbs_from_hex;
pub use self::{constant::limbs_from_hex, inout::InOut};
30 changes: 24 additions & 6 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ pub(crate) use self::{
modulusvalue::OwnedModulusValue,
private_exponent::PrivateExponent,
};
use super::{montgomery::*, InOut};
use crate::{
arithmetic::montgomery::*,
bits::BitLength,
c, error,
limb::{self, Limb, LIMB_BITS},
Expand Down Expand Up @@ -99,7 +99,13 @@ fn from_montgomery_amm<M>(limbs: BoxedLimbs<M>, m: &Modulus<M>) -> Elem<M, Unenc
let mut one = [0; MODULUS_MAX_LIMBS];
one[0] = 1;
let one = &one[..m.limbs().len()];
limbs_mont_mul(&mut limbs, one, m.limbs(), m.n0(), m.cpu_features());
limbs_mul_mont(
InOut::InPlace(&mut limbs),
one,
m.limbs(),
m.n0(),
m.cpu_features(),
);
Elem {
limbs,
encoding: PhantomData,
Expand Down Expand Up @@ -144,7 +150,13 @@ pub fn elem_mul<M, AF, BF>(
where
(AF, BF): ProductEncoding,
{
limbs_mont_mul(&mut b.limbs, &a.limbs, m.limbs(), m.n0(), m.cpu_features());
limbs_mul_mont(
InOut::InPlace(&mut b.limbs),
&a.limbs,
m.limbs(),
m.n0(),
m.cpu_features(),
);
Elem {
limbs: b.limbs,
encoding: PhantomData,
Expand Down Expand Up @@ -205,7 +217,7 @@ fn elem_squared<M, E>(
where
(E, E): ProductEncoding,
{
limbs_mont_square(&mut a.limbs, m.limbs(), m.n0(), m.cpu_features());
limbs_square_mont(&mut a.limbs, m.limbs(), m.n0(), m.cpu_features());
Elem {
limbs: a.limbs,
encoding: PhantomData,
Expand Down Expand Up @@ -467,7 +479,13 @@ pub fn elem_exp_consttime<M>(
let src1 = entry(previous, src1, num_limbs);
let src2 = entry(previous, src2, num_limbs);
let dst = entry_mut(rest, 0, num_limbs);
limbs_mont_product(dst, src1, src2, m.limbs(), m.n0(), m.cpu_features());
limbs_mul_mont(
InOut::Disjoint(dst, src1),
src2,
m.limbs(),
m.n0(),
m.cpu_features(),
);
}

let tmp = m.zero();
Expand Down Expand Up @@ -638,7 +656,7 @@ pub fn elem_exp_consttime<M>(
if i >= TABLE_ENTRIES as LeakyWindow {
break;
}
limbs_mont_square(acc, m_cached, n0, cpu_features);
limbs_square_mont(acc, m_cached, n0, cpu_features);
}
}

Expand Down
19 changes: 19 additions & 0 deletions src/arithmetic/inout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright 2025 Brian Smith.
//
// Permission to use, copy, modify, and/or distribute this software for any
// purpose with or without fee is hereby granted, provided that the above
// copyright notice and this permission notice appear in all copies.
//
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

pub enum InOut<'io, T: ?Sized> {
InPlace(&'io mut T),
#[cfg_attr(target_arch = "x86_64", allow(dead_code))]
Disjoint(&'io mut T, &'io T),
}
99 changes: 26 additions & 73 deletions src/arithmetic/montgomery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

pub use super::n0::N0;
pub use super::{n0::N0, InOut};
use crate::cpu;

// Indicates that the element is not encoded; there is no *R* factor
Expand Down Expand Up @@ -113,15 +113,24 @@ impl ProductEncoding for (RRR, RInverse) {
use crate::{bssl, c, limb::Limb};

#[inline(always)]
unsafe fn mul_mont(
r: *mut Limb,
a: *const Limb,
b: *const Limb,
n: *const Limb,
n0: &N0,
num_limbs: c::size_t,
_: cpu::Features,
) {
pub(super) fn limbs_mul_mont(ra: InOut<[Limb]>, b: &[Limb], n: &[Limb], n0: &N0, _: cpu::Features) {
// XXX/TODO: All the `debug_assert_eq!` length checking needs to be
// replaced with enforcement that happens regardless of debug mode.
let (r, a) = match ra {
InOut::InPlace(r) => {
debug_assert_eq!(r.len(), n.len());
(r.as_mut_ptr(), r.as_ptr())
}
InOut::Disjoint(r, a) => {
debug_assert_eq!(r.len(), n.len());
debug_assert_eq!(a.len(), n.len());
(r.as_mut_ptr(), a.as_ptr())
}
};
debug_assert_eq!(b.len(), n.len());
let b = b.as_ptr();
let num_limbs = n.len();
let n = n.as_ptr();
unsafe { bn_mul_mont(r, a, b, n, n0, num_limbs) }
}

Expand Down Expand Up @@ -249,71 +258,15 @@ prefixed_extern! {
);
}

/// r *= a
pub(super) fn limbs_mont_mul(
r: &mut [Limb],
a: &[Limb],
m: &[Limb],
n0: &N0,
cpu_features: cpu::Features,
) {
debug_assert_eq!(r.len(), m.len());
debug_assert_eq!(a.len(), m.len());
unsafe {
mul_mont(
r.as_mut_ptr(),
r.as_ptr(),
a.as_ptr(),
m.as_ptr(),
n0,
r.len(),
cpu_features,
)
}
}

/// r = a * b
#[cfg(not(target_arch = "x86_64"))]
pub(super) fn limbs_mont_product(
r: &mut [Limb],
a: &[Limb],
b: &[Limb],
m: &[Limb],
n0: &N0,
cpu_features: cpu::Features,
) {
debug_assert_eq!(r.len(), m.len());
debug_assert_eq!(a.len(), m.len());
debug_assert_eq!(b.len(), m.len());

unsafe {
mul_mont(
r.as_mut_ptr(),
a.as_ptr(),
b.as_ptr(),
m.as_ptr(),
n0,
r.len(),
cpu_features,
)
}
}

/// r = r**2
pub(super) fn limbs_mont_square(r: &mut [Limb], m: &[Limb], n0: &N0, cpu_features: cpu::Features) {
debug_assert_eq!(r.len(), m.len());
unsafe {
mul_mont(
r.as_mut_ptr(),
r.as_ptr(),
r.as_ptr(),
m.as_ptr(),
n0,
r.len(),
cpu_features,
)
}
pub(super) fn limbs_square_mont(r: &mut [Limb], n: &[Limb], n0: &N0, _cpu: cpu::Features) {
debug_assert_eq!(r.len(), n.len());
let r = r.as_mut_ptr();
let num_limbs = n.len();
let n = n.as_ptr();
unsafe { bn_mul_mont(r, r, r, n, n0, num_limbs) }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down