From a85e930c32beb00170f5d6d262ee03c2d62e6c76 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 16 Jul 2021 16:12:04 -0700 Subject: [PATCH 1/2] Generalize point-scalar multiplication to allow optimized linear combinations --- k256/src/arithmetic/mul.rs | 181 ++++++++++++++++++++++++++++--------- 1 file changed, 138 insertions(+), 43 deletions(-) diff --git a/k256/src/arithmetic/mul.rs b/k256/src/arithmetic/mul.rs index 2e07ea1d..d1125bed 100644 --- a/k256/src/arithmetic/mul.rs +++ b/k256/src/arithmetic/mul.rs @@ -70,6 +70,7 @@ use core::ops::{Mul, MulAssign}; use elliptic_curve::subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; /// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]` +#[derive(Copy, Clone, Default)] struct LookupTable([ProjectivePoint; 8]); impl From<&ProjectivePoint> for LookupTable { @@ -147,67 +148,161 @@ fn decompose_scalar(k: &Scalar) -> (Scalar, Scalar) { (r1, r2) } -/// Returns `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`, -/// and `-8 <= a_j <= 7`. -/// Assumes `x < 2^128`. -fn to_radix_16_half(x: &Scalar) -> [i8; 33] { - // `x` can have up to 256 bits, so we need an additional byte to store the carry. - let mut output = [0i8; 33]; - - // Step 1: change radix. - // Convert from radix 256 (bytes) to radix 16 (nibbles) - let bytes = x.to_bytes(); - for i in 0..16 { - output[2 * i] = (bytes[31 - i] & 0xf) as i8; - output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8; - } +// This needs to be an object to have Default implemented for it +// (required because it's used in static_map later) +// Otherwise we could just have a function returning an array. +#[derive(Copy, Clone)] +struct Radix16Decomposition([i8; 33]); + +impl Radix16Decomposition { + /// Returns an object containing a decomposition + /// `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`, + /// and `-8 <= a_j <= 7`. + /// Assumes `x < 2^128`. + fn new(x: &Scalar) -> Self { + debug_assert!((x >> 128).is_zero().unwrap_u8() == 1); + + // The resulting decomposition can be negative, so, despite the limit on `x`, + // it can have up to 256 bits, and we need an additional byte to store the carry. + let mut output = [0i8; 33]; + + // Step 1: change radix. + // Convert from radix 256 (bytes) to radix 16 (nibbles) + let bytes = x.to_bytes(); + for i in 0..16 { + output[2 * i] = (bytes[31 - i] & 0xf) as i8; + output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8; + } - debug_assert!((x >> 128).is_zero().unwrap_u8() == 1); + // Step 2: recenter coefficients from [0,16) to [-8,8) + for i in 0..32 { + let carry = (output[i] + 8) >> 4; + output[i] -= carry << 4; + output[i + 1] += carry; + } - // Step 2: recenter coefficients from [0,16) to [-8,8) - for i in 0..32 { - let carry = (output[i] + 8) >> 4; - output[i] -= carry << 4; - output[i + 1] += carry; + Self(output) } - - output } -fn mul_windowed(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint { - let (r1, r2) = decompose_scalar(k); - let x_beta = x.endomorphism(); +impl Default for Radix16Decomposition { + fn default() -> Self { + Self([0i8; 33]) + } +} - let r1_sign = r1.is_high(); - let r1_c = Scalar::conditional_select(&r1, &-r1, r1_sign); - let r2_sign = r2.is_high(); - let r2_c = Scalar::conditional_select(&r2, &-r2, r2_sign); +/// Maps an array `x` to an array using the predicate `f`. +/// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays. +/// Consequently, since we cannot have an uninitialized array (without `unsafe`), +/// a default value needs to be provided. +fn static_map( + f: impl Fn(T) -> V, + x: &[T; N], + default: V, +) -> [V; N] { + let mut res = [default; N]; + for i in 0..N { + res[i] = f(x[i]); + } + res +} - let table1 = LookupTable::from(&ProjectivePoint::conditional_select(x, &-x, r1_sign)); - let table2 = LookupTable::from(&ProjectivePoint::conditional_select( - &x_beta, &-x_beta, r2_sign, - )); +/// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments. +fn static_zip_map( + f: impl Fn(T, S) -> V, + x: &[T; N], + y: &[S; N], + default: V, +) -> [V; N] { + let mut res = [default; N]; + for i in 0..N { + res[i] = f(x[i], y[i]); + } + res +} - let digits1 = to_radix_16_half(&r1_c); - let digits2 = to_radix_16_half(&r2_c); +/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N` +#[inline(always)] +fn lincomb_generic(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint { + let rs = static_map( + |k| decompose_scalar(&k), + ks, + (Scalar::default(), Scalar::default()), + ); + let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default()); + let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default()); + + let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default()); + + let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8)); + let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8)); + + let r1s_c = static_zip_map( + |r, r_sign| Scalar::conditional_select(&r, &-r, r_sign), + &r1s, + &r1_signs, + Scalar::default(), + ); + let r2s_c = static_zip_map( + |r, r_sign| Scalar::conditional_select(&r, &-r, r_sign), + &r2s, + &r2_signs, + Scalar::default(), + ); + + let tables1 = static_zip_map( + |x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)), + &xs, + &r1_signs, + LookupTable::default(), + ); + let tables2 = static_zip_map( + |x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)), + &xs_beta, + &r2_signs, + LookupTable::default(), + ); + + let digits1 = static_map( + |r| Radix16Decomposition::new(&r), + &r1s_c, + Radix16Decomposition::default(), + ); + let digits2 = static_map( + |r| Radix16Decomposition::new(&r), + &r2s_c, + Radix16Decomposition::default(), + ); + + let mut acc = ProjectivePoint::identity(); + for component in 0..N { + acc += &tables1[component].select(digits1[component].0[32]); + acc += &tables2[component].select(digits2[component].0[32]); + } - let mut acc = table1.select(digits1[32]) + table2.select(digits2[32]); for i in (0..32).rev() { for _j in 0..4 { acc = acc.double(); } - acc += &table1.select(digits1[i]); - acc += &table2.select(digits2[i]); + for component in 0..N { + acc += &tables1[component].select(digits1[component].0[i]); + acc += &tables2[component].select(digits2[component].0[i]); + } } acc } +#[inline(always)] +fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint { + lincomb_generic(&[*x], &[*k]) +} + impl Mul for ProjectivePoint { type Output = ProjectivePoint; fn mul(self, other: Scalar) -> ProjectivePoint { - mul_windowed(&self, &other) + mul(&self, &other) } } @@ -215,7 +310,7 @@ impl Mul<&Scalar> for &ProjectivePoint { type Output = ProjectivePoint; fn mul(self, other: &Scalar) -> ProjectivePoint { - mul_windowed(self, other) + mul(self, other) } } @@ -223,18 +318,18 @@ impl Mul<&Scalar> for ProjectivePoint { type Output = ProjectivePoint; fn mul(self, other: &Scalar) -> ProjectivePoint { - mul_windowed(&self, other) + mul(&self, other) } } impl MulAssign for ProjectivePoint { fn mul_assign(&mut self, rhs: Scalar) { - *self = mul_windowed(self, &rhs); + *self = mul(self, &rhs); } } impl MulAssign<&Scalar> for ProjectivePoint { fn mul_assign(&mut self, rhs: &Scalar) { - *self = mul_windowed(self, rhs); + *self = mul(self, rhs); } } From 5c14bea456532065332d021970e4dd09ea601908 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 16 Jul 2021 16:12:57 -0700 Subject: [PATCH 2/2] Add lincomb() as an alias for a 2-point linear combination --- k256/bench/scalar.rs | 11 ++++++++++- k256/src/arithmetic.rs | 1 + k256/src/arithmetic/mul.rs | 30 ++++++++++++++++++++++++++++++ k256/src/ecdsa/recoverable.rs | 4 ++-- k256/src/ecdsa/verify.rs | 14 ++++++++++---- k256/src/lib.rs | 2 +- 6 files changed, 54 insertions(+), 8 deletions(-) diff --git a/k256/bench/scalar.rs b/k256/bench/scalar.rs index 6be3901d..091b691e 100644 --- a/k256/bench/scalar.rs +++ b/k256/bench/scalar.rs @@ -6,7 +6,7 @@ use criterion::{ use hex_literal::hex; use k256::{ elliptic_curve::{generic_array::arr, group::ff::PrimeField}, - ProjectivePoint, Scalar, + lincomb, ProjectivePoint, Scalar, }; fn test_scalar_x() -> Scalar { @@ -34,9 +34,18 @@ fn bench_point_mul<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) { group.bench_function("point-scalar mul", |b| b.iter(|| &p * &s)); } +fn bench_point_lincomb<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) { + let p = ProjectivePoint::generator(); + let m = hex!("AA5E28D6A97A2479A65527F7290311A3624D4CC0FA1578598EE3C2613BF99522"); + let s = Scalar::from_repr(m.into()).unwrap(); + group.bench_function("lincomb via mul+add", |b| b.iter(|| &p * &s + &p * &s)); + group.bench_function("lincomb()", |b| b.iter(|| lincomb(&p, &s, &p, &s))); +} + fn bench_high_level(c: &mut Criterion) { let mut group = c.benchmark_group("high-level operations"); bench_point_mul(&mut group); + bench_point_lincomb(&mut group); group.finish(); } diff --git a/k256/src/arithmetic.rs b/k256/src/arithmetic.rs index 2c18d648..ea876610 100644 --- a/k256/src/arithmetic.rs +++ b/k256/src/arithmetic.rs @@ -8,6 +8,7 @@ pub(crate) mod scalar; mod util; pub use field::FieldElement; +pub use mul::lincomb; use affine::AffinePoint; use projective::ProjectivePoint; diff --git a/k256/src/arithmetic/mul.rs b/k256/src/arithmetic/mul.rs index d1125bed..f6275e7e 100644 --- a/k256/src/arithmetic/mul.rs +++ b/k256/src/arithmetic/mul.rs @@ -298,6 +298,16 @@ fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint { lincomb_generic(&[*x], &[*k]) } +/// Calculates `x * k + y * l`. +pub fn lincomb( + x: &ProjectivePoint, + k: &Scalar, + y: &ProjectivePoint, + l: &Scalar, +) -> ProjectivePoint { + lincomb_generic(&[*x, *y], &[*k, *l]) +} + impl Mul for ProjectivePoint { type Output = ProjectivePoint; @@ -333,3 +343,23 @@ impl MulAssign<&Scalar> for ProjectivePoint { *self = mul(self, rhs); } } + +#[cfg(test)] +mod tests { + use super::lincomb; + use crate::arithmetic::{ProjectivePoint, Scalar}; + use elliptic_curve::rand_core::OsRng; + use elliptic_curve::{Field, Group}; + + #[test] + fn test_lincomb() { + let x = ProjectivePoint::random(&mut OsRng); + let y = ProjectivePoint::random(&mut OsRng); + let k = Scalar::random(&mut OsRng); + let l = Scalar::random(&mut OsRng); + + let reference = &x * &k + &y * &l; + let test = lincomb(&x, &k, &y, &l); + assert_eq!(reference, test); + } +} diff --git a/k256/src/ecdsa/recoverable.rs b/k256/src/ecdsa/recoverable.rs index dd8dd7e5..36f05147 100644 --- a/k256/src/ecdsa/recoverable.rs +++ b/k256/src/ecdsa/recoverable.rs @@ -51,7 +51,7 @@ use crate::{ consts::U32, generic_array::GenericArray, ops::Invert, subtle::Choice, weierstrass::DecompressPoint, }, - AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar, + lincomb, AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar, }; #[cfg(feature = "keccak256")] @@ -185,7 +185,7 @@ impl Signature { let r_inv = r.invert().unwrap(); let u1 = -(r_inv * z); let u2 = r_inv * *s; - let pk = ((ProjectivePoint::generator() * u1) + (R * u2)).to_affine(); + let pk = lincomb(&ProjectivePoint::generator(), &u1, &R, &u2).to_affine(); // TODO(tarcieri): ensure the signature verifies? Ok(VerifyingKey::from(&pk)) diff --git a/k256/src/ecdsa/verify.rs b/k256/src/ecdsa/verify.rs index 57c9ef90..a5bb5624 100644 --- a/k256/src/ecdsa/verify.rs +++ b/k256/src/ecdsa/verify.rs @@ -2,7 +2,8 @@ use super::{recoverable, Error, Signature}; use crate::{ - AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar, Secp256k1, + lincomb, AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar, + Secp256k1, }; use core::convert::TryFrom; use ecdsa_core::{hazmat::VerifyPrimitive, signature}; @@ -90,9 +91,14 @@ impl VerifyPrimitive for AffinePoint { let u1 = z * &s_inv; let u2 = *r * s_inv; - let x = ((ProjectivePoint::generator() * u1) + (ProjectivePoint::from(*self) * u2)) - .to_affine() - .x; + let x = lincomb( + &ProjectivePoint::generator(), + &u1, + &ProjectivePoint::from(*self), + &u2, + ) + .to_affine() + .x; if Scalar::from_bytes_reduced(&x.to_bytes()).eq(&r) { Ok(()) diff --git a/k256/src/lib.rs b/k256/src/lib.rs index 2bc85323..b45af6d8 100644 --- a/k256/src/lib.rs +++ b/k256/src/lib.rs @@ -67,7 +67,7 @@ pub mod test_vectors; pub use elliptic_curve::{self, bigint::U256}; #[cfg(feature = "arithmetic")] -pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar}; +pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar}; #[cfg(feature = "expose-field")] pub use arithmetic::FieldElement;