diff --git a/.changelog/5592.feature.md b/.changelog/5592.feature.md new file mode 100644 index 00000000000..1121355ec69 --- /dev/null +++ b/.changelog/5592.feature.md @@ -0,0 +1 @@ +secret-sharing/src/vss: Add Lagrange interpolation diff --git a/secret-sharing/src/lib.rs b/secret-sharing/src/lib.rs index 575c52d7779..120f3e04dac 100644 --- a/secret-sharing/src/lib.rs +++ b/secret-sharing/src/lib.rs @@ -8,5 +8,7 @@ //! //! - CHURP (CHUrn-Robust Proactive secret sharing) +#![feature(test)] + pub mod churp; pub mod vss; diff --git a/secret-sharing/src/vss/lagrange/mod.rs b/secret-sharing/src/vss/lagrange/mod.rs new file mode 100644 index 00000000000..f4919554e1f --- /dev/null +++ b/secret-sharing/src/vss/lagrange/mod.rs @@ -0,0 +1,8 @@ +//! Lagrange interpolation. + +mod multiplier; +mod naive; +mod optimized; + +// Re-exports. +pub use self::{naive::*, optimized::*}; diff --git a/secret-sharing/src/vss/lagrange/multiplier.rs b/secret-sharing/src/vss/lagrange/multiplier.rs new file mode 100644 index 00000000000..24eb84d494c --- /dev/null +++ b/secret-sharing/src/vss/lagrange/multiplier.rs @@ -0,0 +1,196 @@ +use std::ops::Mul; + +/// Multiplier efficiently computes the product of all values except one. +/// +/// The multiplier constructs a tree where leaf nodes represent given values, +/// and internal nodes represent the product of their children's values. +/// To obtain the product of all values except one, traverse down the tree +/// to the node containing that value and multiply the values of sibling nodes +/// encountered along the way. +pub struct Multiplier +where + T: Mul + Clone + Default, +{ + /// The root node of the tree. + root: Node, +} + +impl Multiplier +where + T: Mul + Clone + Default, +{ + /// Constructs a new multiplier using the given values. + pub fn new(values: &[T]) -> Self { + let root = Self::create(values, true); + + Self { root } + } + + /// Helper function to recursively construct the tree. + fn create(values: &[T], root: bool) -> Node { + match values.len() { + 0 => { + // When given an empty slice, return zero, which should be the default value. + return Node::Leaf(LeafNode { + value: Default::default(), + }); + } + 1 => { + // Store values in the leaf nodes. + return Node::Leaf(LeafNode { + value: values[0].clone(), + }); + } + _ => (), + } + + let size = values.len(); + let middle = size / 2; + let left = Box::new(Self::create(&values[..middle], false)); + let right = Box::new(Self::create(&values[middle..], false)); + let value = match root { + true => None, + false => Some(left.get_value() * right.get_value()), + }; + + Node::Internal(InternalNode { + value, + left, + right, + size, + }) + } + + /// Returns the product of all values except the one at the given index. + pub fn get_product(&self, index: usize) -> T { + self.root.get_product(index).unwrap_or_default() + } +} + +/// Represents a node in the tree. +enum Node { + /// Internal nodes store the product of their children's values. + Internal(InternalNode), + /// Leaf nodes store given values. + Leaf(LeafNode), +} + +impl Node +where + T: Mul + Clone, +{ + /// Returns the value stored in the node. + /// + /// # Panics + /// + /// This function panics if called on the root node. + fn get_value(&self) -> T { + match self { + Node::Internal(n) => n.value.clone().expect("should not be called on root node"), + Node::Leaf(n) => n.value.clone(), + } + } + + /// Returns the number of leaf nodes in the subtree. + fn get_size(&self) -> usize { + match self { + Node::Internal(n) => n.size, + Node::Leaf(_) => 1, + } + } + + /// Returns the product of all values stored in the subtree except + /// the one at the given index. + fn get_product(&self, index: usize) -> Option { + match self { + Node::Internal(n) => { + let left_size = n.left.get_size(); + match index < left_size { + true => { + if let Some(value) = n.left.get_product(index) { + Some(n.right.get_value() * value) + } else { + Some(n.right.get_value()) + } + } + false => { + if let Some(value) = n.right.get_product(index - left_size) { + Some(n.left.get_value() * value) + } else { + Some(n.left.get_value()) + } + } + } + } + Node::Leaf(n) => { + if index > 0 { + Some(n.value.clone()) + } else { + None + } + } + } + } +} + +/// Represents an internal node in the tree. +struct InternalNode { + /// The product of its children's values. + /// + /// Optional for the root node. + value: Option, + /// The left child node. + left: Box>, + /// The right child node. + right: Box>, + /// The number of leaf nodes in the subtree. + size: usize, +} + +/// Represents a leaf node in the tree. +struct LeafNode { + /// The value stored in the leaf node. + value: T, +} + +#[cfg(test)] +mod tests { + use super::Multiplier; + + #[test] + fn test_multiplier() { + // No values. + let m = Multiplier::::new(&vec![]); + for i in 0..10 { + let product = m.get_product(i); + assert_eq!(product, 0); + } + + // One value. + let values = vec![1]; + let products = vec![0, 1, 1]; + let m = Multiplier::new(&values); + + for (i, expected) in products.into_iter().enumerate() { + let product = m.get_product(i); + assert_eq!(product, expected); + } + + // Many values. + let values = vec![1, 2, 3, 4, 5]; + let total = values.iter().fold(1, |acc, x| acc * x); + let products = values.iter().map(|x| total / x); + let m = Multiplier::new(&values); + + for (i, expected) in products.enumerate() { + let product = m.get_product(i); + assert_eq!(product, expected); + } + + // Index out of bounds. + for i in 5..10 { + let product = m.get_product(i); + assert_eq!(product, total); + } + } +} diff --git a/secret-sharing/src/vss/lagrange/naive.rs b/secret-sharing/src/vss/lagrange/naive.rs new file mode 100644 index 00000000000..1de9f8b8237 --- /dev/null +++ b/secret-sharing/src/vss/lagrange/naive.rs @@ -0,0 +1,153 @@ +// Lagrange Polynomials interpolation / reconstruction +use std::iter::zip; + +use group::ff::PrimeField; + +use crate::vss::polynomial::Polynomial; + +/// Returns the Lagrange interpolation polynomial for the given set of points. +/// +/// The Lagrange polynomial is defined as: +/// ```text +/// L(x) = \sum_{i=0}^n y_i * L_i(x) +/// ``` +/// where `L_i(x)` represents the i-th Lagrange basis polynomial. +pub fn lagrange_naive(xs: &[Fp], ys: &[Fp]) -> Polynomial +where + Fp: PrimeField, +{ + let ls = (0..xs.len()) + .map(|i| basis_polynomial_naive(i, xs)) + .collect::>(); + + zip(ls, ys).map(|(li, &yi)| li * yi).sum() +} + +/// Returns i-th Lagrange basis polynomials for the given set of x values. +/// +/// The i-th Lagrange basis polynomial is defined as: +/// ```text +/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) +/// ``` +/// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. +fn basis_polynomial_naive(i: usize, xs: &[Fp]) -> Polynomial +where + Fp: PrimeField, +{ + let mut nom = Polynomial::with_coefficients(vec![Fp::ONE]); + let mut denom = Fp::ONE; + for j in 0..xs.len() { + if j == i { + continue; + } + nom *= Polynomial::with_coefficients(vec![xs[j], Fp::ONE.neg()]); // (x_j - x) + denom *= xs[j] - xs[i]; // (x_j - x_i) + } + let denom_inv = denom.invert().expect("values should be unique"); + nom *= denom_inv; // L_i(x) = nom / denom + + nom +} + +#[cfg(test)] +mod tests { + extern crate test; + + use self::test::Bencher; + + use std::iter::zip; + + use group::ff::Field; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + use super::{basis_polynomial_naive, lagrange_naive}; + + fn scalar(value: i64) -> p384::Scalar { + scalars(&vec![value])[0] + } + + fn scalars(values: &[i64]) -> Vec { + values + .iter() + .map(|&w| match w.is_negative() { + false => p384::Scalar::from_u64(w as u64), + true => p384::Scalar::from_u64(-w as u64).neg(), + }) + .collect() + } + + fn random_scalars(n: usize, mut rng: &mut impl RngCore) -> Vec { + (0..n).map(|_| p384::Scalar::random(&mut rng)).collect() + } + + #[test] + fn test_lagrange_naive() { + let xs = scalars(&[1, 2, 3]); + let ys = scalars(&[2, 4, 8]); + let p = lagrange_naive(&xs, &ys); + + // Verify zeros. + for (x, y) in zip(xs, ys) { + assert_eq!(p.eval(&x), y); + } + + // Verify degree. + assert_eq!(p.highest_degree(), 2); + } + + #[test] + fn test_basis_polynomial_naive() { + let xs = scalars(&[1, 2, 3]); + + for i in 0..xs.len() { + let p = basis_polynomial_naive(i, &xs); + + // Verify points. + for (j, x) in xs.iter().enumerate() { + if j == i { + assert_eq!(p.eval(x), scalar(1)); // L_i(x_i) = 1 + } else { + assert_eq!(p.eval(x), scalar(0)); // L_i(x_j) = 0 + } + } + + // Verify degree. + assert_eq!(p.highest_degree(), 2); + } + } + + fn bench_lagrange_naive(b: &mut Bencher, n: usize) { + let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); + let xs = random_scalars(n, &mut rng); + let ys = random_scalars(n, &mut rng); + + b.iter(|| { + let _p = lagrange_naive(&xs, &ys); + }); + } + + #[bench] + fn bench_lagrange_naive_1(b: &mut Bencher) { + bench_lagrange_naive(b, 1) + } + + #[bench] + fn bench_lagrange_naive_2(b: &mut Bencher) { + bench_lagrange_naive(b, 2) + } + + #[bench] + fn bench_lagrange_naive_5(b: &mut Bencher) { + bench_lagrange_naive(b, 5) + } + + #[bench] + fn bench_lagrange_naive_10(b: &mut Bencher) { + bench_lagrange_naive(b, 10) + } + + #[bench] + fn bench_lagrange_naive_20(b: &mut Bencher) { + bench_lagrange_naive(b, 20) + } +} diff --git a/secret-sharing/src/vss/lagrange/optimized.rs b/secret-sharing/src/vss/lagrange/optimized.rs new file mode 100644 index 00000000000..31969b3b934 --- /dev/null +++ b/secret-sharing/src/vss/lagrange/optimized.rs @@ -0,0 +1,172 @@ +use std::iter::zip; + +use group::ff::PrimeField; + +use crate::vss::polynomial::Polynomial; + +use super::multiplier::Multiplier; + +/// Returns the Lagrange interpolation polynomial for the given set of points. +/// +/// The Lagrange polynomial is defined as: +/// ```text +/// L(x) = \sum_{i=0}^n y_i * L_i(x) +/// ``` +/// where `L_i(x)` represents the i-th Lagrange basis polynomial. +pub fn lagrange(xs: &[Fp], ys: &[Fp]) -> Polynomial +where + Fp: PrimeField, +{ + let m = multiplier(xs); + let ls = (0..xs.len()) + .map(|i| basis_polynomial(i, xs, &m)) + .collect::>(); + + zip(ls, ys).map(|(li, &yi)| li * yi).sum() +} + +/// Returns i-th Lagrange basis polynomials for the given set of x values. +/// +/// The i-th Lagrange basis polynomial is defined as: +/// ```text +/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) +/// ``` +/// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. +fn basis_polynomial( + i: usize, + xs: &[Fp], + multiplier: &Multiplier>, +) -> Polynomial +where + Fp: PrimeField, +{ + let mut nom = multiplier.get_product(i); + let mut denom = Fp::ONE; + for j in 0..xs.len() { + if j == i { + continue; + } + denom *= xs[j] - xs[i]; // (x_j - x_i) + } + let denom_inv = denom.invert().expect("values should be unique"); + nom *= denom_inv; // L_i(x) = nom / denom + + nom +} + +/// Creates a multiplier for the nominators in the Lagrange basis polynomials. +fn multiplier(xs: &[Fp]) -> Multiplier> +where + Fp: PrimeField, +{ + let basis: Vec<_> = xs + .iter() + .map(|x| Polynomial::with_coefficients(vec![*x, Fp::ONE.neg()])) // (x_j - x) + .collect(); + Multiplier::new(&basis) +} + +#[cfg(test)] +mod tests { + extern crate test; + + use self::test::Bencher; + + use std::iter::zip; + + use group::ff::Field; + use rand::{rngs::StdRng, RngCore, SeedableRng}; + + use super::{basis_polynomial, lagrange, multiplier}; + + fn scalar(value: i64) -> p384::Scalar { + scalars(&vec![value])[0] + } + + fn scalars(values: &[i64]) -> Vec { + values + .iter() + .map(|&w| match w.is_negative() { + false => p384::Scalar::from_u64(w as u64), + true => p384::Scalar::from_u64(-w as u64).neg(), + }) + .collect() + } + + fn random_scalars(n: usize, mut rng: &mut impl RngCore) -> Vec { + (0..n).map(|_| p384::Scalar::random(&mut rng)).collect() + } + + #[test] + fn test_lagrange() { + let xs = scalars(&[1, 2, 3]); + let ys = scalars(&[2, 4, 8]); + let p = lagrange(&xs, &ys); + + // Verify zeros. + for (x, y) in zip(xs, ys) { + assert_eq!(p.eval(&x), y); + } + + // Verify degree. + assert_eq!(p.highest_degree(), 2); + } + + #[test] + fn test_basis_polynomial() { + let xs = scalars(&[1, 2, 3]); + let m = multiplier(&xs); + + for i in 0..xs.len() { + let p = basis_polynomial(i, &xs, &m); + + // Verify points. + for (j, x) in xs.iter().enumerate() { + if j == i { + assert_eq!(p.eval(x), scalar(1)); // L_i(x_i) = 1 + } else { + assert_eq!(p.eval(x), scalar(0)); // L_i(x_j) = 0 + } + } + + // Verify degree. + assert_eq!(p.highest_degree(), 2); + } + } + + fn bench_lagrange(b: &mut Bencher, n: usize) { + let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); + + let xs = random_scalars(n, &mut rng); + let ys = random_scalars(n, &mut rng); + + b.iter(|| { + let _p = lagrange(&xs, &ys); + }); + } + + #[bench] + fn bench_lagrange_1(b: &mut Bencher) { + bench_lagrange(b, 1) + } + + #[bench] + fn bench_lagrange_2(b: &mut Bencher) { + bench_lagrange(b, 2) + } + + #[bench] + fn bench_lagrange_5(b: &mut Bencher) { + bench_lagrange(b, 5) + } + + #[bench] + fn bench_lagrange_10(b: &mut Bencher) { + bench_lagrange(b, 10) + } + + #[bench] + fn bench_lagrange_20(b: &mut Bencher) { + bench_lagrange(b, 20) + } +} diff --git a/secret-sharing/src/vss/mod.rs b/secret-sharing/src/vss/mod.rs index d8ff880f48e..8cdfef67d74 100644 --- a/secret-sharing/src/vss/mod.rs +++ b/secret-sharing/src/vss/mod.rs @@ -1,5 +1,6 @@ //! Verifiable secret sharing. pub mod arith; +pub mod lagrange; pub mod matrix; pub mod polynomial; diff --git a/secret-sharing/src/vss/polynomial/univariate.rs b/secret-sharing/src/vss/polynomial/univariate.rs index 7340a6803bc..61a48054d75 100644 --- a/secret-sharing/src/vss/polynomial/univariate.rs +++ b/secret-sharing/src/vss/polynomial/univariate.rs @@ -14,6 +14,14 @@ use crate::vss::arith::powers; /// ```text /// A(x) = \sum_{i=0}^{deg_x} a_i x^i /// ``` +/// +/// The constant zero polynomial is represented by a vector with one zero +/// element, rather than by an empty vector. +/// +/// Trailing zeros are never trimmed to ensure that all polynomials of the same +/// degree are consistently represented by vectors of the same size, resulting +/// in encodings of equal length. If you wish to remove them, consider using +/// the `trim` method after each operation. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Polynomial { a: Vec, @@ -53,6 +61,31 @@ where Self { a } } + /// Returns the highest of the degrees of the polynomial's monomials with + /// non-zero coefficients. + pub fn degree(&self) -> usize { + let mut deg = self.a.len().saturating_sub(1); + for ai in self.a.iter().rev() { + if ai.is_zero().into() { + deg = deg.saturating_sub(1); + } + } + + deg + } + + /// Returns the highest of the degrees of the polynomial's monomials. + pub fn highest_degree(&self) -> usize { + self.a.len() - 1 + } + + /// Removes trailing zeros. + pub fn trim(&mut self) { + while self.a.len() > 1 && self.a[self.a.len() - 1].is_zero().into() { + _ = self.a.pop(); + } + } + /// Returns the byte representation of the polynomial. pub fn to_bytes(&self) -> Vec { let cap = Self::byte_size(self.a.len()); @@ -114,6 +147,15 @@ where } } +impl Default for Polynomial +where + Fp: PrimeField, +{ + fn default() -> Self { + Self::zero(0) + } +} + impl Add for Polynomial where Fp: PrimeField, @@ -307,6 +349,52 @@ mod tests { assert_eq!(p.a, scalars(&[1, 2, 3])); } + #[test] + fn test_degree() { + let p = Polynomial::::with_coefficients(vec![]); + assert_eq!(p.degree(), 0); + assert_eq!(p.highest_degree(), 0); + + let p = Polynomial::::with_coefficients(scalars(&[0])); + assert_eq!(p.degree(), 0); + assert_eq!(p.highest_degree(), 0); + + let p = Polynomial::::with_coefficients(scalars(&[1])); + assert_eq!(p.degree(), 0); + assert_eq!(p.highest_degree(), 0); + + let p = Polynomial::::with_coefficients(scalars(&[0, 0])); + assert_eq!(p.degree(), 0); + assert_eq!(p.highest_degree(), 1); + + let p = Polynomial::::with_coefficients(scalars(&[1, 2, 3])); + assert_eq!(p.degree(), 2); + assert_eq!(p.highest_degree(), 2); + + let p = Polynomial::::with_coefficients(scalars(&[1, 2, 3, 0, 0])); + assert_eq!(p.degree(), 2); + assert_eq!(p.highest_degree(), 4); + } + + #[test] + fn test_trim() { + let mut p = Polynomial::::with_coefficients(scalars(&[0])); + p.trim(); + assert_eq!(p.a, scalars(&[0])); + + let mut p = Polynomial::::with_coefficients(scalars(&[0, 0])); + p.trim(); + assert_eq!(p.a, scalars(&[0])); + + let mut p = Polynomial::::with_coefficients(scalars(&[1, 2, 3])); + p.trim(); + assert_eq!(p.a, scalars(&[1, 2, 3])); + + let mut p = Polynomial::::with_coefficients(scalars(&[1, 2, 3, 0, 0])); + p.trim(); + assert_eq!(p.a, scalars(&[1, 2, 3])); + } + #[test] fn test_serialization() { let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]);