Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
weikengchen committed Aug 15, 2024
1 parent cb13941 commit 6c7abef
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 392 deletions.
4 changes: 2 additions & 2 deletions crates/prover/benches/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ use num_traits::Zero;
use stwo_prover::core::backend::simd::SimdBackend;
use stwo_prover::core::backend::{Col, CpuBackend};
use stwo_prover::core::fields::m31::{BaseField, N_BYTES_FELT};
use stwo_prover::core::vcs::sha256_merkle::Sha256MerkleChannel;
use stwo_prover::core::vcs::sha256_merkle::Sha256MerkleHasher;
use stwo_prover::core::vcs::ops::MerkleOps;

const LOG_N_ROWS: u32 = 16;

const LOG_N_COLS: u32 = 8;

fn bench_sha256_merkle<B: MerkleOps<Sha256MerkleChannel>>(c: &mut Criterion, id: &str) {
fn bench_sha256_merkle<B: MerkleOps<Sha256MerkleHasher>>(c: &mut Criterion, id: &str) {
let col: Col<B, BaseField> = (0..1 << LOG_N_ROWS).map(|_| BaseField::zero()).collect();
let cols = (0..1 << LOG_N_COLS).map(|_| col.clone()).collect_vec();
let col_refs = cols.iter().collect_vec();
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/benches/pcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ use stwo_prover::core::poly::circle::{CanonicCoset, CircleEvaluation};
use stwo_prover::core::poly::twiddles::TwiddleTree;
use stwo_prover::core::poly::BitReversedOrder;
use stwo_prover::core::vcs::sha256_hash::Sha256Hash;
use stwo_prover::core::vcs::sha256_merkle::Sha256MerkleChannel;
use stwo_prover::core::vcs::sha256_merkle::Sha256MerkleHasher;

const LOG_COSET_SIZE: u32 = 20;
const LOG_BLOWUP_FACTOR: u32 = 1;
const N_POLYS: usize = 16;

fn benched_fn<B: BackendForChannel<Sha256MerkleChannel>>(
fn benched_fn<B: BackendForChannel<Sha256MerkleHasher>>(
evals: Vec<CircleEvaluation<B, BaseField, BitReversedOrder>>,
channel: &mut Sha256Channel,
twiddles: &TwiddleTree<B>,
Expand All @@ -28,15 +28,15 @@ fn benched_fn<B: BackendForChannel<Sha256MerkleChannel>>(
.map(|eval| eval.interpolate_with_twiddles(twiddles))
.collect();

CommitmentTreeProver::<B, Sha256MerkleChannel>::new(
CommitmentTreeProver::<B, Sha256MerkleHasher>::new(
polys,
LOG_BLOWUP_FACTOR,
channel,
twiddles,
);
}

fn bench_pcs<B: BackendForChannel<Sha256MerkleChannel>>(c: &mut Criterion, id: &str) {
fn bench_pcs<B: BackendForChannel<Sha256MerkleHasher>>(c: &mut Criterion, id: &str) {
let small_domain = CanonicCoset::new(LOG_COSET_SIZE);
let big_domain = CanonicCoset::new(LOG_COSET_SIZE + LOG_BLOWUP_FACTOR);
let twiddles = B::precompute_twiddles(big_domain.half_coset());
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/core/backend/cpu/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@ use itertools::Itertools;
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::vcs::sha256_hash::Sha256Hash;
use crate::core::vcs::sha256_merkle::Sha256MerkleChannel;
use crate::core::vcs::sha256_merkle::Sha256MerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};

impl MerkleOps<Sha256MerkleChannel> for CpuBackend {
impl MerkleOps<Sha256MerkleHasher> for CpuBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Sha256Hash>>,
columns: &[&Vec<BaseField>],
) -> Vec<Sha256Hash> {
(0..(1 << log_size))
.map(|i| {
Sha256MerkleChannel::hash_node(
Sha256MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column[i]).collect_vec(),
)
Expand Down
66 changes: 11 additions & 55 deletions crates/prover/src/core/backend/simd/grind.rs
Original file line number Diff line number Diff line change
@@ -1,65 +1,21 @@
use std::simd::cmp::SimdPartialOrd;
use std::simd::num::SimdUint;
use std::simd::u32x16;

use bytemuck::cast_slice;
#[cfg(feature = "parallel")]
use rayon::prelude::*;

use super::blake2s::compress16;
use super::SimdBackend;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::channel::Blake2sChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::channel::{Channel, Poseidon252Channel};
use crate::core::channel::Sha256Channel;
use crate::core::proof_of_work::GrindOps;

// Note: GRIND_LOW_BITS is a cap on how much extra time we need to wait for all threads to finish.
const GRIND_LOW_BITS: u32 = 20;
const GRIND_HI_BITS: u32 = 64 - GRIND_LOW_BITS;

impl GrindOps<Blake2sChannel> for SimdBackend {
fn grind(channel: &Blake2sChannel, pow_bits: u32) -> u64 {
// TODO(spapini): support more than 32 bits.
assert!(pow_bits <= 32, "pow_bits > 32 is not supported");
let digest = channel.digest();
let digest: &[u32] = cast_slice(&digest.0[..]);

#[cfg(not(feature = "parallel"))]
let res = (0..=(1 << GRIND_HI_BITS))
.find_map(|hi| grind_blake(digest, hi, pow_bits))
.expect("Grind failed to find a solution.");

#[cfg(feature = "parallel")]
let res = (0..=(1 << GRIND_HI_BITS))
.into_par_iter()
.find_map_any(|hi| grind_blake(digest, hi, pow_bits))
.expect("Grind failed to find a solution.");

res
}
}

fn grind_blake(digest: &[u32], hi: u64, pow_bits: u32) -> Option<u64> {
let zero: u32x16 = u32x16::default();
let pow_bits = u32x16::splat(pow_bits);

let state: [u32x16; 8] = std::array::from_fn(|i| u32x16::splat(digest[i]));

let mut attempt = [zero; 16];
attempt[0] = u32x16::splat((hi << GRIND_LOW_BITS) as u32);
attempt[0] += u32x16::from(std::array::from_fn(|i| i as u32));
attempt[1] = u32x16::splat((hi >> (32 - GRIND_LOW_BITS)) as u32);
for low in (0..(1 << GRIND_LOW_BITS)).step_by(N_LANES) {
let res = compress16(state, attempt, zero, zero, zero, zero);
let success_mask = res[0].trailing_zeros().simd_ge(pow_bits);
if success_mask.any() {
let i = success_mask.to_array().iter().position(|&x| x).unwrap();
return Some((hi << GRIND_LOW_BITS) + low as u64 + i as u64);
impl GrindOps<Sha256Channel> for SimdBackend {
fn grind(channel: &Sha256Channel, pow_bits: u32) -> u64 {
let mut nonce = 0;
loop {
let mut channel = channel.clone();
channel.mix_nonce(nonce);
if channel.trailing_zeros() >= pow_bits {
return nonce;
}
nonce += 1;
}
attempt[0] += u32x16::splat(N_LANES as u32);
}
None
}

// TODO(spapini): This is a naive implementation. Optimize it.
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/core/backend/simd/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::vcs::sha256_hash::Sha256Hash;
use crate::core::vcs::sha256_merkle::Sha256MerkleChannel;
use crate::core::vcs::sha256_merkle::Sha256MerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};

impl ColumnOps<Sha256Hash> for SimdBackend {
Expand All @@ -18,7 +18,7 @@ impl ColumnOps<Sha256Hash> for SimdBackend {
}

// TODO(BWS): not simd at all
impl MerkleOps<Sha256MerkleChannel> for SimdBackend {
impl MerkleOps<Sha256MerkleHasher> for SimdBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Sha256Hash>>,
Expand All @@ -31,7 +31,7 @@ impl MerkleOps<Sha256MerkleChannel> for SimdBackend {
let iter = (0..1 << log_size).into_par_iter();

iter.map(|i| {
Sha256MerkleChannel::hash_node(
Sha256MerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
Expand Down
5 changes: 2 additions & 3 deletions crates/prover/src/core/channel/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Sha256Channel {
}

impl Sha256Channel {
pub fn digest(&self) -> Self::Digest {
pub fn digest(&self) -> Sha256Hash {
self.digest
}

Expand All @@ -36,7 +36,7 @@ impl Channel for Sha256Channel {
let mut hasher = Sha256::new();
Digest::update(&mut hasher, sha256_qm31(felt));
Digest::update(&mut hasher, self.digest);
self.update_digest(hasher.finalize().into());
self.update_digest(hasher.finalize().as_slice().into());
}
}

Expand Down Expand Up @@ -117,7 +117,6 @@ mod tests {

use crate::core::channel::{Sha256Channel, Channel};
use crate::core::fields::qm31::SecureField;
use crate::core::vcs::sha256_hash::Sha256Hash;
use crate::m31;

#[test]
Expand Down
40 changes: 1 addition & 39 deletions crates/prover/src/core/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,23 +212,8 @@ impl<B: FriOps + MerkleOps<MC::H>, MC: MerkleChannel> FriProver<B, MC> {
);

while layer_evaluation.len() > config.last_layer_domain_size() {
<<<<<<< HEAD
// Check for any columns (circle poly evaluations) that should be combined.
while let Some(column) = columns.next_if(|c| folded_len(c) == layer_evaluation.len()) {
B::fold_circle_into_line(
&mut layer_evaluation,
column,
circle_poly_alpha,
twiddles,
);
}

let layer = FriLayerProver::new(layer_evaluation);
MC::mix_root(channel, layer.merkle_tree.root());
=======
let layer = FriLayerProver::new(layer_evaluation);
channel.mix_digest(layer.merkle_tree.root());
>>>>>>> b8a18a5 (Cumulative changes to stwo)
let folding_alpha = channel.draw_felt();
let folded_layer_evaluation = B::fold_line(&layer.evaluation, folding_alpha, twiddles);

Expand Down Expand Up @@ -364,11 +349,7 @@ impl<MC: MerkleChannel> FriVerifier<MC> {
));

for (layer_index, proof) in proof.inner_layers.into_iter().enumerate() {
<<<<<<< HEAD
MC::mix_root(channel, proof.commitment);
=======
channel.mix_digest(proof.commitment);
>>>>>>> b8a18a5 (Cumulative changes to stwo)

let folding_alpha = channel.draw_felt();

Expand Down Expand Up @@ -465,22 +446,15 @@ impl<MC: MerkleChannel> FriVerifier<MC> {
.next_if(|b| b.fold_to_line() == layer.degree_bound)
.is_some()
{
<<<<<<< HEAD
=======
assert!(!insertion); // enforce that this can only be performed once
insertion = true;

>>>>>>> b8a18a5 (Cumulative changes to stwo)
let sparse_evaluation = decommitted_values.next().unwrap();
let folded_evals = sparse_evaluation.fold(circle_poly_alpha);
assert_eq!(folded_evals.len(), layer_query_evals.len());

for (layer_eval, folded_eval) in zip(&mut layer_query_evals, folded_evals) {
<<<<<<< HEAD
*layer_eval = *layer_eval * circle_poly_alpha_sq + folded_eval;
=======
*layer_eval = folded_eval;
>>>>>>> b8a18a5 (Cumulative changes to stwo)
}
}

Expand Down Expand Up @@ -864,10 +838,6 @@ impl<B: FriOps + MerkleOps<H>, H: MerkleHasher> FriLayerProver<B, H> {
.collect(),
self.evaluation.values.columns.iter().collect_vec(),
);
<<<<<<< HEAD
=======
// let decomposition_coeff = self.decomposition_coeff;
>>>>>>> b8a18a5 (Cumulative changes to stwo)

FriLayerProof {
evals_subset,
Expand Down Expand Up @@ -1038,21 +1008,13 @@ mod tests {
use crate::core::queries::{Queries, SparseSubCircleDomain};
use crate::core::test_utils::test_channel;
use crate::core::utils::bit_reverse_index;
<<<<<<< HEAD
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
=======
use crate::core::vcs::sha256_merkle::Sha256MerkleChannel;
>>>>>>> b8a18a5 (Cumulative changes to stwo)

/// Default blowup factor used for tests.
const LOG_BLOWUP_FACTOR: u32 = 2;

<<<<<<< HEAD
type FriProver = super::FriProver<CpuBackend, Blake2sMerkleChannel>;
type FriVerifier = super::FriVerifier<Blake2sMerkleChannel>;
=======
type FriProver = super::FriProver<CpuBackend, Sha256MerkleChannel>;
>>>>>>> b8a18a5 (Cumulative changes to stwo)
type FriVerifier = super::FriVerifier<Sha256MerkleChannel>;

#[test]
fn fold_line_works() {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/pcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ pub struct PcsConfig {
impl Default for PcsConfig {
fn default() -> Self {
Self {
pow_bits: 5,
fri_config: FriConfig::new(0, 1, 3),
pow_bits: 20,
fri_config: FriConfig::new(0, 10, 8),
}
}
}
Loading

0 comments on commit 6c7abef

Please sign in to comment.