-
Notifications
You must be signed in to change notification settings - Fork 115
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5592 from oasisprotocol/peternose/feature/lagrange
secret-sharing/src/vss: Add Lagrange interpolation
- Loading branch information
Showing
8 changed files
with
621 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
secret-sharing/src/vss: Add Lagrange interpolation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,5 +8,7 @@ | |
//! | ||
//! - CHURP (CHUrn-Robust Proactive secret sharing) | ||
#![feature(test)] | ||
|
||
pub mod churp; | ||
pub mod vss; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
//! Lagrange interpolation. | ||
mod multiplier; | ||
mod naive; | ||
mod optimized; | ||
|
||
// Re-exports. | ||
pub use self::{naive::*, optimized::*}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
use std::ops::Mul; | ||
|
||
/// Multiplier efficiently computes the product of all values except one. | ||
/// | ||
/// The multiplier constructs a tree where leaf nodes represent given values, | ||
/// and internal nodes represent the product of their children's values. | ||
/// To obtain the product of all values except one, traverse down the tree | ||
/// to the node containing that value and multiply the values of sibling nodes | ||
/// encountered along the way. | ||
pub struct Multiplier<T> | ||
where | ||
T: Mul<Output = T> + Clone + Default, | ||
{ | ||
/// The root node of the tree. | ||
root: Node<T>, | ||
} | ||
|
||
impl<T> Multiplier<T> | ||
where | ||
T: Mul<Output = T> + Clone + Default, | ||
{ | ||
/// Constructs a new multiplier using the given values. | ||
pub fn new(values: &[T]) -> Self { | ||
let root = Self::create(values, true); | ||
|
||
Self { root } | ||
} | ||
|
||
/// Helper function to recursively construct the tree. | ||
fn create(values: &[T], root: bool) -> Node<T> { | ||
match values.len() { | ||
0 => { | ||
// When given an empty slice, return zero, which should be the default value. | ||
return Node::Leaf(LeafNode { | ||
value: Default::default(), | ||
}); | ||
} | ||
1 => { | ||
// Store values in the leaf nodes. | ||
return Node::Leaf(LeafNode { | ||
value: values[0].clone(), | ||
}); | ||
} | ||
_ => (), | ||
} | ||
|
||
let size = values.len(); | ||
let middle = size / 2; | ||
let left = Box::new(Self::create(&values[..middle], false)); | ||
let right = Box::new(Self::create(&values[middle..], false)); | ||
let value = match root { | ||
true => None, | ||
false => Some(left.get_value() * right.get_value()), | ||
}; | ||
|
||
Node::Internal(InternalNode { | ||
value, | ||
left, | ||
right, | ||
size, | ||
}) | ||
} | ||
|
||
/// Returns the product of all values except the one at the given index. | ||
pub fn get_product(&self, index: usize) -> T { | ||
self.root.get_product(index).unwrap_or_default() | ||
} | ||
} | ||
|
||
/// Represents a node in the tree. | ||
enum Node<T> { | ||
/// Internal nodes store the product of their children's values. | ||
Internal(InternalNode<T>), | ||
/// Leaf nodes store given values. | ||
Leaf(LeafNode<T>), | ||
} | ||
|
||
impl<T> Node<T> | ||
where | ||
T: Mul<Output = T> + Clone, | ||
{ | ||
/// Returns the value stored in the node. | ||
/// | ||
/// # Panics | ||
/// | ||
/// This function panics if called on the root node. | ||
fn get_value(&self) -> T { | ||
match self { | ||
Node::Internal(n) => n.value.clone().expect("should not be called on root node"), | ||
Node::Leaf(n) => n.value.clone(), | ||
} | ||
} | ||
|
||
/// Returns the number of leaf nodes in the subtree. | ||
fn get_size(&self) -> usize { | ||
match self { | ||
Node::Internal(n) => n.size, | ||
Node::Leaf(_) => 1, | ||
} | ||
} | ||
|
||
/// Returns the product of all values stored in the subtree except | ||
/// the one at the given index. | ||
fn get_product(&self, index: usize) -> Option<T> { | ||
match self { | ||
Node::Internal(n) => { | ||
let left_size = n.left.get_size(); | ||
match index < left_size { | ||
true => { | ||
if let Some(value) = n.left.get_product(index) { | ||
Some(n.right.get_value() * value) | ||
} else { | ||
Some(n.right.get_value()) | ||
} | ||
} | ||
false => { | ||
if let Some(value) = n.right.get_product(index - left_size) { | ||
Some(n.left.get_value() * value) | ||
} else { | ||
Some(n.left.get_value()) | ||
} | ||
} | ||
} | ||
} | ||
Node::Leaf(n) => { | ||
if index > 0 { | ||
Some(n.value.clone()) | ||
} else { | ||
None | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// Represents an internal node in the tree. | ||
struct InternalNode<T> { | ||
/// The product of its children's values. | ||
/// | ||
/// Optional for the root node. | ||
value: Option<T>, | ||
/// The left child node. | ||
left: Box<Node<T>>, | ||
/// The right child node. | ||
right: Box<Node<T>>, | ||
/// The number of leaf nodes in the subtree. | ||
size: usize, | ||
} | ||
|
||
/// Represents a leaf node in the tree. | ||
struct LeafNode<T> { | ||
/// The value stored in the leaf node. | ||
value: T, | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::Multiplier; | ||
|
||
#[test] | ||
fn test_multiplier() { | ||
// No values. | ||
let m = Multiplier::<usize>::new(&vec![]); | ||
for i in 0..10 { | ||
let product = m.get_product(i); | ||
assert_eq!(product, 0); | ||
} | ||
|
||
// One value. | ||
let values = vec![1]; | ||
let products = vec![0, 1, 1]; | ||
let m = Multiplier::new(&values); | ||
|
||
for (i, expected) in products.into_iter().enumerate() { | ||
let product = m.get_product(i); | ||
assert_eq!(product, expected); | ||
} | ||
|
||
// Many values. | ||
let values = vec![1, 2, 3, 4, 5]; | ||
let total = values.iter().fold(1, |acc, x| acc * x); | ||
let products = values.iter().map(|x| total / x); | ||
let m = Multiplier::new(&values); | ||
|
||
for (i, expected) in products.enumerate() { | ||
let product = m.get_product(i); | ||
assert_eq!(product, expected); | ||
} | ||
|
||
// Index out of bounds. | ||
for i in 5..10 { | ||
let product = m.get_product(i); | ||
assert_eq!(product, total); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
// Lagrange Polynomials interpolation / reconstruction | ||
use std::iter::zip; | ||
|
||
use group::ff::PrimeField; | ||
|
||
use crate::vss::polynomial::Polynomial; | ||
|
||
/// Returns the Lagrange interpolation polynomial for the given set of points. | ||
/// | ||
/// The Lagrange polynomial is defined as: | ||
/// ```text | ||
/// L(x) = \sum_{i=0}^n y_i * L_i(x) | ||
/// ``` | ||
/// where `L_i(x)` represents the i-th Lagrange basis polynomial. | ||
pub fn lagrange_naive<Fp>(xs: &[Fp], ys: &[Fp]) -> Polynomial<Fp> | ||
where | ||
Fp: PrimeField, | ||
{ | ||
let ls = (0..xs.len()) | ||
.map(|i| basis_polynomial_naive(i, xs)) | ||
.collect::<Vec<_>>(); | ||
|
||
zip(ls, ys).map(|(li, &yi)| li * yi).sum() | ||
} | ||
|
||
/// Returns i-th Lagrange basis polynomials for the given set of x values. | ||
/// | ||
/// The i-th Lagrange basis polynomial is defined as: | ||
/// ```text | ||
/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) | ||
/// ``` | ||
/// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. | ||
fn basis_polynomial_naive<Fp>(i: usize, xs: &[Fp]) -> Polynomial<Fp> | ||
where | ||
Fp: PrimeField, | ||
{ | ||
let mut nom = Polynomial::with_coefficients(vec![Fp::ONE]); | ||
let mut denom = Fp::ONE; | ||
for j in 0..xs.len() { | ||
if j == i { | ||
continue; | ||
} | ||
nom *= Polynomial::with_coefficients(vec![xs[j], Fp::ONE.neg()]); // (x_j - x) | ||
denom *= xs[j] - xs[i]; // (x_j - x_i) | ||
} | ||
let denom_inv = denom.invert().expect("values should be unique"); | ||
nom *= denom_inv; // L_i(x) = nom / denom | ||
|
||
nom | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
extern crate test; | ||
|
||
use self::test::Bencher; | ||
|
||
use std::iter::zip; | ||
|
||
use group::ff::Field; | ||
use rand::{rngs::StdRng, RngCore, SeedableRng}; | ||
|
||
use super::{basis_polynomial_naive, lagrange_naive}; | ||
|
||
fn scalar(value: i64) -> p384::Scalar { | ||
scalars(&vec![value])[0] | ||
} | ||
|
||
fn scalars(values: &[i64]) -> Vec<p384::Scalar> { | ||
values | ||
.iter() | ||
.map(|&w| match w.is_negative() { | ||
false => p384::Scalar::from_u64(w as u64), | ||
true => p384::Scalar::from_u64(-w as u64).neg(), | ||
}) | ||
.collect() | ||
} | ||
|
||
fn random_scalars(n: usize, mut rng: &mut impl RngCore) -> Vec<p384::Scalar> { | ||
(0..n).map(|_| p384::Scalar::random(&mut rng)).collect() | ||
} | ||
|
||
#[test] | ||
fn test_lagrange_naive() { | ||
let xs = scalars(&[1, 2, 3]); | ||
let ys = scalars(&[2, 4, 8]); | ||
let p = lagrange_naive(&xs, &ys); | ||
|
||
// Verify zeros. | ||
for (x, y) in zip(xs, ys) { | ||
assert_eq!(p.eval(&x), y); | ||
} | ||
|
||
// Verify degree. | ||
assert_eq!(p.highest_degree(), 2); | ||
} | ||
|
||
#[test] | ||
fn test_basis_polynomial_naive() { | ||
let xs = scalars(&[1, 2, 3]); | ||
|
||
for i in 0..xs.len() { | ||
let p = basis_polynomial_naive(i, &xs); | ||
|
||
// Verify points. | ||
for (j, x) in xs.iter().enumerate() { | ||
if j == i { | ||
assert_eq!(p.eval(x), scalar(1)); // L_i(x_i) = 1 | ||
} else { | ||
assert_eq!(p.eval(x), scalar(0)); // L_i(x_j) = 0 | ||
} | ||
} | ||
|
||
// Verify degree. | ||
assert_eq!(p.highest_degree(), 2); | ||
} | ||
} | ||
|
||
fn bench_lagrange_naive(b: &mut Bencher, n: usize) { | ||
let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); | ||
let xs = random_scalars(n, &mut rng); | ||
let ys = random_scalars(n, &mut rng); | ||
|
||
b.iter(|| { | ||
let _p = lagrange_naive(&xs, &ys); | ||
}); | ||
} | ||
|
||
#[bench] | ||
fn bench_lagrange_naive_1(b: &mut Bencher) { | ||
bench_lagrange_naive(b, 1) | ||
} | ||
|
||
#[bench] | ||
fn bench_lagrange_naive_2(b: &mut Bencher) { | ||
bench_lagrange_naive(b, 2) | ||
} | ||
|
||
#[bench] | ||
fn bench_lagrange_naive_5(b: &mut Bencher) { | ||
bench_lagrange_naive(b, 5) | ||
} | ||
|
||
#[bench] | ||
fn bench_lagrange_naive_10(b: &mut Bencher) { | ||
bench_lagrange_naive(b, 10) | ||
} | ||
|
||
#[bench] | ||
fn bench_lagrange_naive_20(b: &mut Bencher) { | ||
bench_lagrange_naive(b, 20) | ||
} | ||
} |
Oops, something went wrong.