Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized linear combination of points #380

Merged
merged 2 commits into from
Jul 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion k256/bench/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use criterion::{
use hex_literal::hex;
use k256::{
elliptic_curve::{generic_array::arr, group::ff::PrimeField},
ProjectivePoint, Scalar,
lincomb, ProjectivePoint, Scalar,
};

fn test_scalar_x() -> Scalar {
Expand Down Expand Up @@ -34,9 +34,18 @@ fn bench_point_mul<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
group.bench_function("point-scalar mul", |b| b.iter(|| &p * &s));
}

fn bench_point_lincomb<'a, M: Measurement>(group: &mut BenchmarkGroup<'a, M>) {
let p = ProjectivePoint::generator();
let m = hex!("AA5E28D6A97A2479A65527F7290311A3624D4CC0FA1578598EE3C2613BF99522");
let s = Scalar::from_repr(m.into()).unwrap();
group.bench_function("lincomb via mul+add", |b| b.iter(|| &p * &s + &p * &s));
group.bench_function("lincomb()", |b| b.iter(|| lincomb(&p, &s, &p, &s)));
}

fn bench_high_level(c: &mut Criterion) {
let mut group = c.benchmark_group("high-level operations");
bench_point_mul(&mut group);
bench_point_lincomb(&mut group);
group.finish();
}

Expand Down
1 change: 1 addition & 0 deletions k256/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub(crate) mod scalar;
mod util;

pub use field::FieldElement;
pub use mul::lincomb;

use affine::AffinePoint;
use projective::ProjectivePoint;
Expand Down
211 changes: 168 additions & 43 deletions k256/src/arithmetic/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ use core::ops::{Mul, MulAssign};
use elliptic_curve::subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

/// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]`
#[derive(Copy, Clone, Default)]
struct LookupTable([ProjectivePoint; 8]);

impl From<&ProjectivePoint> for LookupTable {
Expand Down Expand Up @@ -147,94 +148,218 @@ fn decompose_scalar(k: &Scalar) -> (Scalar, Scalar) {
(r1, r2)
}

/// Returns `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
/// and `-8 <= a_j <= 7`.
/// Assumes `x < 2^128`.
fn to_radix_16_half(x: &Scalar) -> [i8; 33] {
// `x` can have up to 256 bits, so we need an additional byte to store the carry.
let mut output = [0i8; 33];

// Step 1: change radix.
// Convert from radix 256 (bytes) to radix 16 (nibbles)
let bytes = x.to_bytes();
for i in 0..16 {
output[2 * i] = (bytes[31 - i] & 0xf) as i8;
output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
}
// This needs to be an object to have Default implemented for it
// (required because it's used in static_map later)
// Otherwise we could just have a function returning an array.
#[derive(Copy, Clone)]
struct Radix16Decomposition([i8; 33]);

impl Radix16Decomposition {
/// Returns an object containing a decomposition
/// `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
/// and `-8 <= a_j <= 7`.
/// Assumes `x < 2^128`.
fn new(x: &Scalar) -> Self {
debug_assert!((x >> 128).is_zero().unwrap_u8() == 1);

// The resulting decomposition can be negative, so, despite the limit on `x`,
// it can have up to 256 bits, and we need an additional byte to store the carry.
let mut output = [0i8; 33];

// Step 1: change radix.
// Convert from radix 256 (bytes) to radix 16 (nibbles)
let bytes = x.to_bytes();
for i in 0..16 {
output[2 * i] = (bytes[31 - i] & 0xf) as i8;
output[2 * i + 1] = ((bytes[31 - i] >> 4) & 0xf) as i8;
}

debug_assert!((x >> 128).is_zero().unwrap_u8() == 1);
// Step 2: recenter coefficients from [0,16) to [-8,8)
for i in 0..32 {
let carry = (output[i] + 8) >> 4;
output[i] -= carry << 4;
output[i + 1] += carry;
}

// Step 2: recenter coefficients from [0,16) to [-8,8)
for i in 0..32 {
let carry = (output[i] + 8) >> 4;
output[i] -= carry << 4;
output[i + 1] += carry;
Self(output)
}

output
}

fn mul_windowed(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
let (r1, r2) = decompose_scalar(k);
let x_beta = x.endomorphism();
impl Default for Radix16Decomposition {
fn default() -> Self {
Self([0i8; 33])
}
}

let r1_sign = r1.is_high();
let r1_c = Scalar::conditional_select(&r1, &-r1, r1_sign);
let r2_sign = r2.is_high();
let r2_c = Scalar::conditional_select(&r2, &-r2, r2_sign);
/// Maps an array `x` to an array using the predicate `f`.
/// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays.
/// Consequently, since we cannot have an uninitialized array (without `unsafe`),
/// a default value needs to be provided.
fn static_map<T: Copy, V: Copy, const N: usize>(
f: impl Fn(T) -> V,
x: &[T; N],
default: V,
) -> [V; N] {
let mut res = [default; N];
for i in 0..N {
res[i] = f(x[i]);
}
res
}

let table1 = LookupTable::from(&ProjectivePoint::conditional_select(x, &-x, r1_sign));
let table2 = LookupTable::from(&ProjectivePoint::conditional_select(
&x_beta, &-x_beta, r2_sign,
));
/// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments.
fn static_zip_map<T: Copy, S: Copy, V: Copy, const N: usize>(
f: impl Fn(T, S) -> V,
x: &[T; N],
y: &[S; N],
default: V,
) -> [V; N] {
let mut res = [default; N];
for i in 0..N {
res[i] = f(x[i], y[i]);
}
res
}

let digits1 = to_radix_16_half(&r1_c);
let digits2 = to_radix_16_half(&r2_c);
/// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N`
#[inline(always)]
fn lincomb_generic<const N: usize>(xs: &[ProjectivePoint; N], ks: &[Scalar; N]) -> ProjectivePoint {
let rs = static_map(
|k| decompose_scalar(&k),
ks,
(Scalar::default(), Scalar::default()),
);
let r1s = static_map(|(r1, _r2)| r1, &rs, Scalar::default());
let r2s = static_map(|(_r1, r2)| r2, &rs, Scalar::default());

let xs_beta = static_map(|x| x.endomorphism(), xs, ProjectivePoint::default());

let r1_signs = static_map(|r| r.is_high(), &r1s, Choice::from(0u8));
let r2_signs = static_map(|r| r.is_high(), &r2s, Choice::from(0u8));

let r1s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r1s,
&r1_signs,
Scalar::default(),
);
let r2s_c = static_zip_map(
|r, r_sign| Scalar::conditional_select(&r, &-r, r_sign),
&r2s,
&r2_signs,
Scalar::default(),
);

let tables1 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
&xs,
&r1_signs,
LookupTable::default(),
);
let tables2 = static_zip_map(
|x, r_sign| LookupTable::from(&ProjectivePoint::conditional_select(&x, &-x, r_sign)),
&xs_beta,
&r2_signs,
LookupTable::default(),
);

let digits1 = static_map(
|r| Radix16Decomposition::new(&r),
&r1s_c,
Radix16Decomposition::default(),
);
let digits2 = static_map(
|r| Radix16Decomposition::new(&r),
&r2s_c,
Radix16Decomposition::default(),
);

let mut acc = ProjectivePoint::identity();
for component in 0..N {
acc += &tables1[component].select(digits1[component].0[32]);
acc += &tables2[component].select(digits2[component].0[32]);
}

let mut acc = table1.select(digits1[32]) + table2.select(digits2[32]);
for i in (0..32).rev() {
for _j in 0..4 {
acc = acc.double();
}

acc += &table1.select(digits1[i]);
acc += &table2.select(digits2[i]);
for component in 0..N {
acc += &tables1[component].select(digits1[component].0[i]);
acc += &tables2[component].select(digits2[component].0[i]);
}
}
acc
}

#[inline(always)]
fn mul(x: &ProjectivePoint, k: &Scalar) -> ProjectivePoint {
lincomb_generic(&[*x], &[*k])
}

/// Calculates `x * k + y * l`.
pub fn lincomb(
x: &ProjectivePoint,
k: &Scalar,
y: &ProjectivePoint,
l: &Scalar,
) -> ProjectivePoint {
lincomb_generic(&[*x, *y], &[*k, *l])
}

impl Mul<Scalar> for ProjectivePoint {
type Output = ProjectivePoint;

fn mul(self, other: Scalar) -> ProjectivePoint {
mul_windowed(&self, &other)
mul(&self, &other)
}
}

impl Mul<&Scalar> for &ProjectivePoint {
type Output = ProjectivePoint;

fn mul(self, other: &Scalar) -> ProjectivePoint {
mul_windowed(self, other)
mul(self, other)
}
}

impl Mul<&Scalar> for ProjectivePoint {
type Output = ProjectivePoint;

fn mul(self, other: &Scalar) -> ProjectivePoint {
mul_windowed(&self, other)
mul(&self, other)
}
}

impl MulAssign<Scalar> for ProjectivePoint {
fn mul_assign(&mut self, rhs: Scalar) {
*self = mul_windowed(self, &rhs);
*self = mul(self, &rhs);
}
}

impl MulAssign<&Scalar> for ProjectivePoint {
fn mul_assign(&mut self, rhs: &Scalar) {
*self = mul_windowed(self, rhs);
*self = mul(self, rhs);
}
}

#[cfg(test)]
mod tests {
use super::lincomb;
use crate::arithmetic::{ProjectivePoint, Scalar};
use elliptic_curve::rand_core::OsRng;
use elliptic_curve::{Field, Group};

#[test]
fn test_lincomb() {
let x = ProjectivePoint::random(&mut OsRng);
let y = ProjectivePoint::random(&mut OsRng);
let k = Scalar::random(&mut OsRng);
let l = Scalar::random(&mut OsRng);

let reference = &x * &k + &y * &l;
let test = lincomb(&x, &k, &y, &l);
assert_eq!(reference, test);
}
}
4 changes: 2 additions & 2 deletions k256/src/ecdsa/recoverable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ use crate::{
consts::U32, generic_array::GenericArray, ops::Invert, subtle::Choice,
weierstrass::DecompressPoint,
},
AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
lincomb, AffinePoint, FieldBytes, NonZeroScalar, ProjectivePoint, Scalar,
};

#[cfg(feature = "keccak256")]
Expand Down Expand Up @@ -185,7 +185,7 @@ impl Signature {
let r_inv = r.invert().unwrap();
let u1 = -(r_inv * z);
let u2 = r_inv * *s;
let pk = ((ProjectivePoint::generator() * u1) + (R * u2)).to_affine();
let pk = lincomb(&ProjectivePoint::generator(), &u1, &R, &u2).to_affine();

// TODO(tarcieri): ensure the signature verifies?
Ok(VerifyingKey::from(&pk))
Expand Down
14 changes: 10 additions & 4 deletions k256/src/ecdsa/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

use super::{recoverable, Error, Signature};
use crate::{
AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar, Secp256k1,
lincomb, AffinePoint, CompressedPoint, EncodedPoint, ProjectivePoint, PublicKey, Scalar,
Secp256k1,
};
use core::convert::TryFrom;
use ecdsa_core::{hazmat::VerifyPrimitive, signature};
Expand Down Expand Up @@ -90,9 +91,14 @@ impl VerifyPrimitive<Secp256k1> for AffinePoint {
let u1 = z * &s_inv;
let u2 = *r * s_inv;

let x = ((ProjectivePoint::generator() * u1) + (ProjectivePoint::from(*self) * u2))
.to_affine()
.x;
let x = lincomb(
&ProjectivePoint::generator(),
&u1,
&ProjectivePoint::from(*self),
&u2,
)
.to_affine()
.x;

if Scalar::from_bytes_reduced(&x.to_bytes()).eq(&r) {
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion k256/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub mod test_vectors;
pub use elliptic_curve::{self, bigint::U256};

#[cfg(feature = "arithmetic")]
pub use arithmetic::{affine::AffinePoint, projective::ProjectivePoint, scalar::Scalar};
pub use arithmetic::{affine::AffinePoint, lincomb, projective::ProjectivePoint, scalar::Scalar};

#[cfg(feature = "expose-field")]
pub use arithmetic::FieldElement;
Expand Down