Skip to content

Commit

Permalink
Merge pull request #5592 from oasisprotocol/peternose/feature/lagrange
Browse files Browse the repository at this point in the history
secret-sharing/src/vss: Add Lagrange interpolation
  • Loading branch information
peternose authored Mar 11, 2024
2 parents afd38cc + 6c8d03f commit 955969d
Show file tree
Hide file tree
Showing 8 changed files with 621 additions and 0 deletions.
1 change: 1 addition & 0 deletions .changelog/5592.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
secret-sharing/src/vss: Add Lagrange interpolation
2 changes: 2 additions & 0 deletions secret-sharing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,7 @@
//!
//! - CHURP (CHUrn-Robust Proactive secret sharing)
#![feature(test)]

pub mod churp;
pub mod vss;
8 changes: 8 additions & 0 deletions secret-sharing/src/vss/lagrange/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//! Lagrange interpolation.
mod multiplier;
mod naive;
mod optimized;

// Re-exports.
pub use self::{naive::*, optimized::*};
196 changes: 196 additions & 0 deletions secret-sharing/src/vss/lagrange/multiplier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use std::ops::Mul;

/// Multiplier efficiently computes the product of all values except one.
///
/// The multiplier constructs a tree where leaf nodes represent given values,
/// and internal nodes represent the product of their children's values.
/// To obtain the product of all values except one, traverse down the tree
/// to the node containing that value and multiply the values of sibling nodes
/// encountered along the way.
pub struct Multiplier<T>
where
T: Mul<Output = T> + Clone + Default,
{
/// The root node of the tree.
root: Node<T>,
}

impl<T> Multiplier<T>
where
T: Mul<Output = T> + Clone + Default,
{
/// Constructs a new multiplier using the given values.
pub fn new(values: &[T]) -> Self {
let root = Self::create(values, true);

Self { root }
}

/// Helper function to recursively construct the tree.
fn create(values: &[T], root: bool) -> Node<T> {
match values.len() {
0 => {
// When given an empty slice, return zero, which should be the default value.
return Node::Leaf(LeafNode {
value: Default::default(),
});
}
1 => {
// Store values in the leaf nodes.
return Node::Leaf(LeafNode {
value: values[0].clone(),
});
}
_ => (),
}

let size = values.len();
let middle = size / 2;
let left = Box::new(Self::create(&values[..middle], false));
let right = Box::new(Self::create(&values[middle..], false));
let value = match root {
true => None,
false => Some(left.get_value() * right.get_value()),
};

Node::Internal(InternalNode {
value,
left,
right,
size,
})
}

/// Returns the product of all values except the one at the given index.
pub fn get_product(&self, index: usize) -> T {
self.root.get_product(index).unwrap_or_default()
}
}

/// Represents a node in the tree.
enum Node<T> {
/// Internal nodes store the product of their children's values.
Internal(InternalNode<T>),
/// Leaf nodes store given values.
Leaf(LeafNode<T>),
}

impl<T> Node<T>
where
T: Mul<Output = T> + Clone,
{
/// Returns the value stored in the node.
///
/// # Panics
///
/// This function panics if called on the root node.
fn get_value(&self) -> T {
match self {
Node::Internal(n) => n.value.clone().expect("should not be called on root node"),
Node::Leaf(n) => n.value.clone(),
}
}

/// Returns the number of leaf nodes in the subtree.
fn get_size(&self) -> usize {
match self {
Node::Internal(n) => n.size,
Node::Leaf(_) => 1,
}
}

/// Returns the product of all values stored in the subtree except
/// the one at the given index.
fn get_product(&self, index: usize) -> Option<T> {
match self {
Node::Internal(n) => {
let left_size = n.left.get_size();
match index < left_size {
true => {
if let Some(value) = n.left.get_product(index) {
Some(n.right.get_value() * value)
} else {
Some(n.right.get_value())
}
}
false => {
if let Some(value) = n.right.get_product(index - left_size) {
Some(n.left.get_value() * value)
} else {
Some(n.left.get_value())
}
}
}
}
Node::Leaf(n) => {
if index > 0 {
Some(n.value.clone())
} else {
None
}
}
}
}
}

/// Represents an internal node in the tree.
struct InternalNode<T> {
/// The product of its children's values.
///
/// Optional for the root node.
value: Option<T>,
/// The left child node.
left: Box<Node<T>>,
/// The right child node.
right: Box<Node<T>>,
/// The number of leaf nodes in the subtree.
size: usize,
}

/// Represents a leaf node in the tree.
struct LeafNode<T> {
/// The value stored in the leaf node.
value: T,
}

#[cfg(test)]
mod tests {
use super::Multiplier;

#[test]
fn test_multiplier() {
// No values.
let m = Multiplier::<usize>::new(&vec![]);
for i in 0..10 {
let product = m.get_product(i);
assert_eq!(product, 0);
}

// One value.
let values = vec![1];
let products = vec![0, 1, 1];
let m = Multiplier::new(&values);

for (i, expected) in products.into_iter().enumerate() {
let product = m.get_product(i);
assert_eq!(product, expected);
}

// Many values.
let values = vec![1, 2, 3, 4, 5];
let total = values.iter().fold(1, |acc, x| acc * x);
let products = values.iter().map(|x| total / x);
let m = Multiplier::new(&values);

for (i, expected) in products.enumerate() {
let product = m.get_product(i);
assert_eq!(product, expected);
}

// Index out of bounds.
for i in 5..10 {
let product = m.get_product(i);
assert_eq!(product, total);
}
}
}
153 changes: 153 additions & 0 deletions secret-sharing/src/vss/lagrange/naive.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Lagrange Polynomials interpolation / reconstruction
use std::iter::zip;

use group::ff::PrimeField;

use crate::vss::polynomial::Polynomial;

/// Returns the Lagrange interpolation polynomial for the given set of points.
///
/// The Lagrange polynomial is defined as:
/// ```text
/// L(x) = \sum_{i=0}^n y_i * L_i(x)
/// ```
/// where `L_i(x)` represents the i-th Lagrange basis polynomial.
pub fn lagrange_naive<Fp>(xs: &[Fp], ys: &[Fp]) -> Polynomial<Fp>
where
Fp: PrimeField,
{
let ls = (0..xs.len())
.map(|i| basis_polynomial_naive(i, xs))
.collect::<Vec<_>>();

zip(ls, ys).map(|(li, &yi)| li * yi).sum()
}

/// Returns i-th Lagrange basis polynomials for the given set of x values.
///
/// The i-th Lagrange basis polynomial is defined as:
/// ```text
/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j)
/// ```
/// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`.
fn basis_polynomial_naive<Fp>(i: usize, xs: &[Fp]) -> Polynomial<Fp>
where
Fp: PrimeField,
{
let mut nom = Polynomial::with_coefficients(vec![Fp::ONE]);
let mut denom = Fp::ONE;
for j in 0..xs.len() {
if j == i {
continue;
}
nom *= Polynomial::with_coefficients(vec![xs[j], Fp::ONE.neg()]); // (x_j - x)
denom *= xs[j] - xs[i]; // (x_j - x_i)
}
let denom_inv = denom.invert().expect("values should be unique");
nom *= denom_inv; // L_i(x) = nom / denom

nom
}

#[cfg(test)]
mod tests {
extern crate test;

use self::test::Bencher;

use std::iter::zip;

use group::ff::Field;
use rand::{rngs::StdRng, RngCore, SeedableRng};

use super::{basis_polynomial_naive, lagrange_naive};

fn scalar(value: i64) -> p384::Scalar {
scalars(&vec![value])[0]
}

fn scalars(values: &[i64]) -> Vec<p384::Scalar> {
values
.iter()
.map(|&w| match w.is_negative() {
false => p384::Scalar::from_u64(w as u64),
true => p384::Scalar::from_u64(-w as u64).neg(),
})
.collect()
}

fn random_scalars(n: usize, mut rng: &mut impl RngCore) -> Vec<p384::Scalar> {
(0..n).map(|_| p384::Scalar::random(&mut rng)).collect()
}

#[test]
fn test_lagrange_naive() {
let xs = scalars(&[1, 2, 3]);
let ys = scalars(&[2, 4, 8]);
let p = lagrange_naive(&xs, &ys);

// Verify zeros.
for (x, y) in zip(xs, ys) {
assert_eq!(p.eval(&x), y);
}

// Verify degree.
assert_eq!(p.highest_degree(), 2);
}

#[test]
fn test_basis_polynomial_naive() {
let xs = scalars(&[1, 2, 3]);

for i in 0..xs.len() {
let p = basis_polynomial_naive(i, &xs);

// Verify points.
for (j, x) in xs.iter().enumerate() {
if j == i {
assert_eq!(p.eval(x), scalar(1)); // L_i(x_i) = 1
} else {
assert_eq!(p.eval(x), scalar(0)); // L_i(x_j) = 0
}
}

// Verify degree.
assert_eq!(p.highest_degree(), 2);
}
}

fn bench_lagrange_naive(b: &mut Bencher, n: usize) {
let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]);
let xs = random_scalars(n, &mut rng);
let ys = random_scalars(n, &mut rng);

b.iter(|| {
let _p = lagrange_naive(&xs, &ys);
});
}

#[bench]
fn bench_lagrange_naive_1(b: &mut Bencher) {
bench_lagrange_naive(b, 1)
}

#[bench]
fn bench_lagrange_naive_2(b: &mut Bencher) {
bench_lagrange_naive(b, 2)
}

#[bench]
fn bench_lagrange_naive_5(b: &mut Bencher) {
bench_lagrange_naive(b, 5)
}

#[bench]
fn bench_lagrange_naive_10(b: &mut Bencher) {
bench_lagrange_naive(b, 10)
}

#[bench]
fn bench_lagrange_naive_20(b: &mut Bencher) {
bench_lagrange_naive(b, 20)
}
}
Loading

0 comments on commit 955969d

Please sign in to comment.