Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: avx2 acceleration #167

Merged
merged 7 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 97 additions & 51 deletions starky/src/arch/x86_64/avx2_poseidon_gl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ impl Default for Poseidon {
}
}

#[inline]
#[inline(always)]
unsafe fn spmv_avx_4x12(
r: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let m = Avx2GoldilocksField::pack_slice(&m);
*r = (st0 * m[0]) + (st1 * m[1]) + (st2 * m[2])
let m = Avx2GoldilocksField::pack_slice(m);
*r = (st0 * m[0]) + (st1 * m[1]) + (st2 * m[2]);
}

impl Poseidon {
Expand Down Expand Up @@ -118,9 +118,9 @@ impl Poseidon {
st0: &mut Avx2GoldilocksField,
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
c: Vec<FrRepr>,
c: &[FrRepr],
) {
let c = Avx2GoldilocksField::pack_slice(&c);
let c = Avx2GoldilocksField::pack_slice(c);
*st0 = *st0 + c[0];
*st1 = *st1 + c[1];
*st2 = *st2 + c[2];
Expand All @@ -132,9 +132,9 @@ impl Poseidon {
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
s0: Avx2GoldilocksField,
s: Vec<FrRepr>,
s: &[FrRepr],
) {
let s = Avx2GoldilocksField::pack_slice(&s);
let s = Avx2GoldilocksField::pack_slice(s);
*st0 = *st0 + s[0] * s0;
*st1 = *st1 + s[1] * s0;
*st2 = *st2 + s[2] * s0;
Expand All @@ -145,36 +145,36 @@ impl Poseidon {
st0: &mut Avx2GoldilocksField,
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
p: Vec<FrRepr>,
p: &[FrRepr],
) {
let mut tmp0 = Avx2GoldilocksField::ZEROS;
let mut tmp1 = Avx2GoldilocksField::ZEROS;
let mut tmp2 = Avx2GoldilocksField::ZEROS;
Self::mmult_avx_4x12(&mut tmp0, *st0, *st1, *st2, p[0..48].to_vec());
Self::mmult_avx_4x12(&mut tmp1, *st0, *st1, *st2, p[48..96].to_vec());
Self::mmult_avx_4x12(&mut tmp2, *st0, *st1, *st2, p[96..144].to_vec());
Self::mmult_avx_4x12(&mut tmp0, *st0, *st1, *st2, &p[0..48]);
Self::mmult_avx_4x12(&mut tmp1, *st0, *st1, *st2, &p[48..96]);
Self::mmult_avx_4x12(&mut tmp2, *st0, *st1, *st2, &p[96..144]);
*st0 = tmp0;
*st1 = tmp1;
*st2 = tmp2;
}

// Dense matrix-vector product
#[inline]
#[inline(always)]
unsafe fn mmult_avx_4x12(
tmp: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let mut r0 = Avx2GoldilocksField::ZEROS;
let mut r1 = Avx2GoldilocksField::ZEROS;
let mut r2 = Avx2GoldilocksField::ZEROS;
let mut r3 = Avx2GoldilocksField::ZEROS;
spmv_avx_4x12(&mut r0, st0, st1, st2, m[0..12].to_vec());
spmv_avx_4x12(&mut r1, st0, st1, st2, m[12..24].to_vec());
spmv_avx_4x12(&mut r2, st0, st1, st2, m[24..36].to_vec());
spmv_avx_4x12(&mut r3, st0, st1, st2, m[36..48].to_vec());
spmv_avx_4x12(&mut r0, st0, st1, st2, &m[0..12]);
spmv_avx_4x12(&mut r1, st0, st1, st2, &m[12..24]);
spmv_avx_4x12(&mut r2, st0, st1, st2, &m[24..36]);
spmv_avx_4x12(&mut r3, st0, st1, st2, &m[36..48]);
// Transpose: transform the 4x4 matrix stored in rows r0...r3 into the columns c0...c3
let t0 = _mm256_permute2f128_si256(r0.get(), r2.get(), 0b00100000);
let t1 = _mm256_permute2f128_si256(r1.get(), r3.get(), 0b00100000);
Expand Down Expand Up @@ -205,36 +205,36 @@ impl Poseidon {
st0: &mut Avx2GoldilocksField,
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let mut tmp0 = Avx2GoldilocksField::ZEROS;
let mut tmp1 = Avx2GoldilocksField::ZEROS;
let mut tmp2 = Avx2GoldilocksField::ZEROS;
Self::mmult_avx_4x12_8(&mut tmp0, *st0, *st1, *st2, m[0..48].to_vec());
Self::mmult_avx_4x12_8(&mut tmp1, *st0, *st1, *st2, m[48..96].to_vec());
Self::mmult_avx_4x12_8(&mut tmp2, *st0, *st1, *st2, m[96..144].to_vec());
Self::mmult_avx_4x12_8(&mut tmp0, *st0, *st1, *st2, &m[0..48]);
Self::mmult_avx_4x12_8(&mut tmp1, *st0, *st1, *st2, &m[48..96]);
Self::mmult_avx_4x12_8(&mut tmp2, *st0, *st1, *st2, &m[96..144]);
*st0 = tmp0;
*st1 = tmp1;
*st2 = tmp2;
}

// Dense matrix-vector product
#[inline]
#[inline(always)]
unsafe fn mmult_avx_4x12_8(
tmp: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let mut r0 = Avx2GoldilocksField::ZEROS;
let mut r1 = Avx2GoldilocksField::ZEROS;
let mut r2 = Avx2GoldilocksField::ZEROS;
let mut r3 = Avx2GoldilocksField::ZEROS;
Self::spmv_avx_4x12_8(&mut r0, st0, st1, st2, m[0..12].to_vec());
Self::spmv_avx_4x12_8(&mut r1, st0, st1, st2, m[12..24].to_vec());
Self::spmv_avx_4x12_8(&mut r2, st0, st1, st2, m[24..36].to_vec());
Self::spmv_avx_4x12_8(&mut r3, st0, st1, st2, m[36..48].to_vec());
Self::spmv_avx_4x12_8(&mut r0, st0, st1, st2, &m[0..12]);
Self::spmv_avx_4x12_8(&mut r1, st0, st1, st2, &m[12..24]);
Self::spmv_avx_4x12_8(&mut r2, st0, st1, st2, &m[24..36]);
Self::spmv_avx_4x12_8(&mut r3, st0, st1, st2, &m[36..48]);
// Transpose: transform the 4x4 matrix stored in rows r0...r3 into the columns c0...c3
let t0 = _mm256_permute2f128_si256(r0.get(), r2.get(), 0b00100000);
let t1 = _mm256_permute2f128_si256(r1.get(), r3.get(), 0b00100000);
Expand All @@ -260,13 +260,13 @@ impl Poseidon {
*tmp = c0 + c1 + c2 + c3;
}

#[inline]
#[inline(always)]
unsafe fn spmv_avx_4x12_8(
r: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let m = Avx2GoldilocksField::pack_slice(&m);
let mut c0_h = Avx2GoldilocksField::ZEROS;
Expand All @@ -283,7 +283,7 @@ impl Poseidon {
*r = Avx2GoldilocksField::reduce(c_h.get(), c_l.get())
}

#[inline]
#[inline(always)]
unsafe fn mult_avx_72(
c_h: &mut Avx2GoldilocksField,
c_l: &mut Avx2GoldilocksField,
Expand Down Expand Up @@ -354,20 +354,20 @@ impl Poseidon {
let mut st1 = st[1];
let mut st2 = st[2];

Self::add_avx(&mut st0, &mut st1, &mut st2, (&C[0..12]).to_vec());
Self::add_avx(&mut st0, &mut st1, &mut st2, &C[0..12]);
for r in 0..(n_rounds_f / 2 - 1) {
Self::pow7_triple(&mut st0, &mut st1, &mut st2);
Self::add_avx(
&mut st0,
&mut st1,
&mut st2,
(&C[(r + 1) * 12..((r + 1) * 12 + 12)]).to_vec(),
&C[(r + 1) * 12..((r + 1) * 12 + 12)],
);
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, (&M[0..144]).to_vec());
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, &M[0..144]);
}
Self::pow7_triple(&mut st0, &mut st1, &mut st2);
Self::add_avx(&mut st0, &mut st1, &mut st2, (&C[48..60]).to_vec());
Self::mmult_avx(&mut st0, &mut st1, &mut st2, (&P[0..144]).to_vec());
Self::add_avx(&mut st0, &mut st1, &mut st2, &C[48..60]);
Self::mmult_avx(&mut st0, &mut st1, &mut st2, &P[0..144]);

for r in 0..n_rounds_p {
let st0_slice = st0.as_slice_mut();
Expand All @@ -382,13 +382,7 @@ impl Poseidon {
st0_slice[0] = _st0.as_slice_mut()[0];

let mut tmp = Avx2GoldilocksField::ZEROS;
spmv_avx_4x12(
&mut tmp,
st0,
st1,
st2,
S[12 * 2 * r..(12 * 2 * r + 12)].to_vec(),
);
spmv_avx_4x12(&mut tmp, st0, st1, st2, &S[12 * 2 * r..(12 * 2 * r + 12)]);
let tmp_slice = tmp.as_slice_mut();
let sum = FGL::from_repr(tmp_slice[0]).unwrap()
+ FGL::from_repr(tmp_slice[1]).unwrap()
Expand All @@ -409,7 +403,7 @@ impl Poseidon {
&mut st1,
&mut st2,
*s0,
(&S[(12 * (2 * r + 1))..(12 * (2 * r + 2))]).to_vec(),
&S[(12 * (2 * r + 1))..(12 * (2 * r + 2))],
);

let st0_slice = st0.as_slice_mut();
Expand All @@ -422,14 +416,13 @@ impl Poseidon {
&mut st0,
&mut st1,
&mut st2,
(&C[((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t)
..((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + 12)])
.to_vec(),
&C[((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t)
..((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + 12)],
);
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, (&M[0..144]).to_vec());
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, &M[0..144]);
}
Self::pow7_triple(&mut st0, &mut st1, &mut st2);
Self::mmult_avx(&mut st0, &mut st1, &mut st2, (&M[0..144]).to_vec());
Self::mmult_avx(&mut st0, &mut st1, &mut st2, &M[0..144]);

let st0_slice = st0.as_slice();

Expand All @@ -447,13 +440,19 @@ mod tests {
use algebraic::packed::PackedField;
use plonky::field_gl::Fr as FGL;
use plonky::PrimeField;
use std::time::{Duration, Instant};

#[test]
fn test_poseidon_opt_hash_all_0_avx() {
let poseidon = Poseidon::new();
let input = vec![FGL::ZERO; 8];
let state = vec![FGL::ZERO; 4];

let start = Instant::now();
let res = poseidon.hash(&input, &state, 4).unwrap();
let hash_avx2_duration = start.elapsed();
log::debug!("hash_avx2_duration_0: {:?}", hash_avx2_duration);

let expected = vec![
FGL::from(0x3c18a9786cb0b359u64),
FGL::from(0xc4055e3364a246c3u64),
Expand All @@ -468,7 +467,11 @@ mod tests {
let poseidon = Poseidon::new();
let input = (0u64..8).map(FGL::from).collect::<Vec<FGL>>();
let state = (8u64..12).map(FGL::from).collect::<Vec<FGL>>();
let start = Instant::now();
let res = poseidon.hash(&input, &state, 4).unwrap();
let hash_avx2_duration = start.elapsed();
log::debug!("hash_avx2_duration_1: {:?}", hash_avx2_duration);

let expected = vec![
FGL::from(0xd64e1e3efc5b8e9eu64),
FGL::from(0x53666633020aaa47u64),
Expand All @@ -484,7 +487,11 @@ mod tests {
let init = FGL::ZERO - FGL::ONE;
let input = vec![init; 8];
let state = vec![init; 4];
let start = Instant::now();
let res = poseidon.hash(&input, &state, 4).unwrap();
let hash_avx2_duration = start.elapsed();
log::debug!("hash_avx2_duration_2: {:?}", hash_avx2_duration);

let expected = vec![
FGL::from(0xbe0085cfc57a8357u64),
FGL::from(0xd95af71847d05c09u64),
Expand All @@ -494,6 +501,45 @@ mod tests {
assert_eq!(res, expected);
}

/// Timing probe: hashes the input `0..8` with state `8..12` one hundred times and
/// logs the mean wall-clock duration of a single `hash` call at debug level.
///
/// The hash output itself is not checked here; each call is only unwrapped, so a
/// failing hash still panics the test.
#[test]
fn test_poseidon_opt_hash_1_11_avx_average() {
    const ITERATIONS: u32 = 100;

    let hasher = Poseidon::new();
    let input: Vec<FGL> = (0u64..8).map(FGL::from).collect();
    let state: Vec<FGL> = (8u64..12).map(FGL::from).collect();

    // Time each call on its own so only the hashing itself is accumulated.
    let total_duration: Duration = (0..ITERATIONS)
        .map(|_| {
            let started_at = Instant::now();
            let _digest = hasher.hash(&input, &state, 4).unwrap();
            started_at.elapsed()
        })
        .sum();

    let average_duration = total_duration / ITERATIONS;
    log::debug!("Average hash_avx2_duration_1: {:?}", average_duration);
}

/// Timing probe: hashes an input and state filled with `FGL::ZERO - FGL::ONE`
/// one hundred times and logs the mean wall-clock duration of a single `hash`
/// call at debug level.
///
/// The hash output itself is not checked here; each call is only unwrapped, so a
/// failing hash still panics the test.
#[test]
fn test_poseidon_opt_hash_all_neg_1_avx_average() {
    const ITERATIONS: u32 = 100;

    let hasher = Poseidon::new();
    let minus_one = FGL::ZERO - FGL::ONE;
    let input = vec![minus_one; 8];
    let state = vec![minus_one; 4];

    // Time each call on its own so only the hashing itself is accumulated.
    let total_duration: Duration = (0..ITERATIONS)
        .map(|_| {
            let started_at = Instant::now();
            let _digest = hasher.hash(&input, &state, 4).unwrap();
            started_at.elapsed()
        })
        .sum();

    let average_duration = total_duration / ITERATIONS;
    log::debug!("Average hash_avx2_duration_2: {:?}", average_duration);
}

#[test]
fn test_spmv_avx_4x12() {
let mut out = Avx2GoldilocksField::ZEROS;
Expand All @@ -516,9 +562,9 @@ mod tests {
FrRepr([18446744069414584320]),
]);

let in12 = vec![FrRepr([18446744069414584320]); 12];
let in12 = [FrRepr([18446744069414584320]); 12];
unsafe {
spmv_avx_4x12(&mut out, *in0, *in1, *in2, in12);
spmv_avx_4x12(&mut out, *in0, *in1, *in2, &in12);
};
let tmp_slice = out.as_slice_mut();
let _sum = FGL::from_repr(tmp_slice[0]).unwrap()
Expand Down
4 changes: 2 additions & 2 deletions starky/src/compressor12/compressor12_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ pub fn setup(
// construct and save ExecFile: plonk additions + sMap -> BigUint64Array
pub(super) fn write_exec_file(
exec_file: &str,
adds: &Vec<PlonkAdd>,
s_map: &Vec<Vec<u64>>,
adds: &[PlonkAdd],
s_map: &[Vec<u64>],
) -> Result<()> {
let adds_len = adds.len();
let s_map_row_len = s_map.len();
Expand Down
2 changes: 1 addition & 1 deletion starky/src/compressor12/plonk_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pub(crate) struct NormalPlonkInfo {
}

impl NormalPlonkInfo {
pub(crate) fn new(plonk_constrains: &Vec<PlonkGate>) -> Self {
pub(crate) fn new(plonk_constrains: &[PlonkGate]) -> Self {
let mut uses: BTreeMap<String, usize> = BTreeMap::new();
let plonk_constrains_len = plonk_constrains.len();
for (i, c) in plonk_constrains.iter().enumerate() {
Expand Down
6 changes: 3 additions & 3 deletions starky/src/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ impl<F: FieldExtension> FFT<F> {
}
}

pub fn fft(&mut self, p: &Vec<F>) -> Vec<F> {
pub fn fft(&mut self, p: &[F]) -> Vec<F> {
if p.len() <= 1 {
return p.clone();
return p.to_owned();
}
let bits = log2_any(p.len() - 1) + 1;
self.set_roots(bits);
Expand Down Expand Up @@ -71,7 +71,7 @@ impl<F: FieldExtension> FFT<F> {
buff
}

pub fn ifft(&mut self, p: &Vec<F>) -> Vec<F> {
pub fn ifft(&mut self, p: &[F]) -> Vec<F> {
let q = self.fft(p);
let n = p.len();
let n2inv = F::from(p.len()).inv();
Expand Down
4 changes: 2 additions & 2 deletions starky/src/fri.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl FRI {
}
}

fn get_transposed_buffer<F: FieldExtension>(pol: &Vec<F>, transpose_bits: usize) -> Vec<FGL> {
fn get_transposed_buffer<F: FieldExtension>(pol: &[F], transpose_bits: usize) -> Vec<FGL> {
let n = pol.len();
let w = 1 << transpose_bits;
let h = n / w;
Expand All @@ -292,7 +292,7 @@ fn get3<F: FieldExtension>(arr: &[FGL], idx: usize) -> F {
}

// TODO: Support F5G
fn split3<F: FieldExtension>(arr: &Vec<FGL>) -> Vec<F> {
fn split3<F: FieldExtension>(arr: &[FGL]) -> Vec<F> {
let mut res: Vec<F> = Vec::new();
for i in (0..arr.len()).step_by(3) {
res.push(F::from_vec(vec![arr[i], arr[i + 1], arr[i + 2]]));
Expand Down
Loading
Loading