Skip to content

Commit

Permalink
chore: avx2 acceleration
Browse files Browse the repository at this point in the history
  • Loading branch information
ibmp33 committed Dec 4, 2023
1 parent 895e383 commit d86ee32
Showing 1 changed file with 97 additions and 51 deletions.
148 changes: 97 additions & 51 deletions starky/src/arch/x86_64/avx2_poseidon_gl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ impl Default for Poseidon {
}
}

/// Stores `st0 * m[0] + st1 * m[1] + st2 * m[2]` in `*r`, where `m` is first
/// packed with `Avx2GoldilocksField::pack_slice`.
///
/// Only the first three packed elements of `m` are read.
///
/// # Safety
/// NOTE(review): declared `unsafe`, but the exact precondition is not visible
/// in this view — presumably the CPU must support AVX2; confirm against callers.
#[inline]
#[inline(always)]
unsafe fn spmv_avx_4x12(
r: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let m = Avx2GoldilocksField::pack_slice(&m);
*r = (st0 * m[0]) + (st1 * m[1]) + (st2 * m[2])
let m = Avx2GoldilocksField::pack_slice(m);
*r = (st0 * m[0]) + (st1 * m[1]) + (st2 * m[2]);
}

impl Poseidon {
Expand Down Expand Up @@ -118,9 +118,9 @@ impl Poseidon {
st0: &mut Avx2GoldilocksField,
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
c: Vec<FrRepr>,
c: &[FrRepr],
) {
let c = Avx2GoldilocksField::pack_slice(&c);
let c = Avx2GoldilocksField::pack_slice(c);
*st0 = *st0 + c[0];
*st1 = *st1 + c[1];
*st2 = *st2 + c[2];
Expand All @@ -132,9 +132,9 @@ impl Poseidon {
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
s0: Avx2GoldilocksField,
s: Vec<FrRepr>,
s: &[FrRepr],
) {
let s = Avx2GoldilocksField::pack_slice(&s);
let s = Avx2GoldilocksField::pack_slice(s);
*st0 = *st0 + s[0] * s0;
*st1 = *st1 + s[1] * s0;
*st2 = *st2 + s[2] * s0;
Expand All @@ -145,36 +145,36 @@ impl Poseidon {
st0: &mut Avx2GoldilocksField,
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
p: Vec<FrRepr>,
p: &[FrRepr],
) {
let mut tmp0 = Avx2GoldilocksField::ZEROS;
let mut tmp1 = Avx2GoldilocksField::ZEROS;
let mut tmp2 = Avx2GoldilocksField::ZEROS;
Self::mmult_avx_4x12(&mut tmp0, *st0, *st1, *st2, p[0..48].to_vec());
Self::mmult_avx_4x12(&mut tmp1, *st0, *st1, *st2, p[48..96].to_vec());
Self::mmult_avx_4x12(&mut tmp2, *st0, *st1, *st2, p[96..144].to_vec());
Self::mmult_avx_4x12(&mut tmp0, *st0, *st1, *st2, &p[0..48]);
Self::mmult_avx_4x12(&mut tmp1, *st0, *st1, *st2, &p[48..96]);
Self::mmult_avx_4x12(&mut tmp2, *st0, *st1, *st2, &p[96..144]);
*st0 = tmp0;
*st1 = tmp1;
*st2 = tmp2;
}

// Dense matrix-vector product
#[inline]
#[inline(always)]
unsafe fn mmult_avx_4x12(
tmp: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let mut r0 = Avx2GoldilocksField::ZEROS;
let mut r1 = Avx2GoldilocksField::ZEROS;
let mut r2 = Avx2GoldilocksField::ZEROS;
let mut r3 = Avx2GoldilocksField::ZEROS;
spmv_avx_4x12(&mut r0, st0, st1, st2, m[0..12].to_vec());
spmv_avx_4x12(&mut r1, st0, st1, st2, m[12..24].to_vec());
spmv_avx_4x12(&mut r2, st0, st1, st2, m[24..36].to_vec());
spmv_avx_4x12(&mut r3, st0, st1, st2, m[36..48].to_vec());
spmv_avx_4x12(&mut r0, st0, st1, st2, &m[0..12]);
spmv_avx_4x12(&mut r1, st0, st1, st2, &m[12..24]);
spmv_avx_4x12(&mut r2, st0, st1, st2, &m[24..36]);
spmv_avx_4x12(&mut r3, st0, st1, st2, &m[36..48]);
// Transpose: transform the 4x4 matrix stored in rows r0...r3 to the columns c0...c3
let t0 = _mm256_permute2f128_si256(r0.get(), r2.get(), 0b00100000);
let t1 = _mm256_permute2f128_si256(r1.get(), r3.get(), 0b00100000);
Expand Down Expand Up @@ -205,36 +205,36 @@ impl Poseidon {
st0: &mut Avx2GoldilocksField,
st1: &mut Avx2GoldilocksField,
st2: &mut Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let mut tmp0 = Avx2GoldilocksField::ZEROS;
let mut tmp1 = Avx2GoldilocksField::ZEROS;
let mut tmp2 = Avx2GoldilocksField::ZEROS;
Self::mmult_avx_4x12_8(&mut tmp0, *st0, *st1, *st2, m[0..48].to_vec());
Self::mmult_avx_4x12_8(&mut tmp1, *st0, *st1, *st2, m[48..96].to_vec());
Self::mmult_avx_4x12_8(&mut tmp2, *st0, *st1, *st2, m[96..144].to_vec());
Self::mmult_avx_4x12_8(&mut tmp0, *st0, *st1, *st2, &m[0..48]);
Self::mmult_avx_4x12_8(&mut tmp1, *st0, *st1, *st2, &m[48..96]);
Self::mmult_avx_4x12_8(&mut tmp2, *st0, *st1, *st2, &m[96..144]);
*st0 = tmp0;
*st1 = tmp1;
*st2 = tmp2;
}

// Dense matrix-vector product
#[inline]
#[inline(always)]
unsafe fn mmult_avx_4x12_8(
tmp: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let mut r0 = Avx2GoldilocksField::ZEROS;
let mut r1 = Avx2GoldilocksField::ZEROS;
let mut r2 = Avx2GoldilocksField::ZEROS;
let mut r3 = Avx2GoldilocksField::ZEROS;
Self::spmv_avx_4x12_8(&mut r0, st0, st1, st2, m[0..12].to_vec());
Self::spmv_avx_4x12_8(&mut r1, st0, st1, st2, m[12..24].to_vec());
Self::spmv_avx_4x12_8(&mut r2, st0, st1, st2, m[24..36].to_vec());
Self::spmv_avx_4x12_8(&mut r3, st0, st1, st2, m[36..48].to_vec());
Self::spmv_avx_4x12_8(&mut r0, st0, st1, st2, &m[0..12]);
Self::spmv_avx_4x12_8(&mut r1, st0, st1, st2, &m[12..24]);
Self::spmv_avx_4x12_8(&mut r2, st0, st1, st2, &m[24..36]);
Self::spmv_avx_4x12_8(&mut r3, st0, st1, st2, &m[36..48]);
// Transpose: transform the 4x4 matrix stored in rows r0...r3 to the columns c0...c3
let t0 = _mm256_permute2f128_si256(r0.get(), r2.get(), 0b00100000);
let t1 = _mm256_permute2f128_si256(r1.get(), r3.get(), 0b00100000);
Expand All @@ -260,13 +260,13 @@ impl Poseidon {
*tmp = c0 + c1 + c2 + c3;
}

#[inline]
#[inline(always)]
unsafe fn spmv_avx_4x12_8(
r: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
m: &[FrRepr],
) {
let m = Avx2GoldilocksField::pack_slice(&m);
let mut c0_h = Avx2GoldilocksField::ZEROS;
Expand All @@ -283,7 +283,7 @@ impl Poseidon {
*r = Avx2GoldilocksField::reduce(c_h.get(), c_l.get())
}

#[inline]
#[inline(always)]
unsafe fn mult_avx_72(
c_h: &mut Avx2GoldilocksField,
c_l: &mut Avx2GoldilocksField,
Expand Down Expand Up @@ -354,20 +354,20 @@ impl Poseidon {
let mut st1 = st[1];
let mut st2 = st[2];

Self::add_avx(&mut st0, &mut st1, &mut st2, (&C[0..12]).to_vec());
Self::add_avx(&mut st0, &mut st1, &mut st2, &C[0..12]);
for r in 0..(n_rounds_f / 2 - 1) {
Self::pow7_triple(&mut st0, &mut st1, &mut st2);
Self::add_avx(
&mut st0,
&mut st1,
&mut st2,
(&C[(r + 1) * 12..((r + 1) * 12 + 12)]).to_vec(),
&C[(r + 1) * 12..((r + 1) * 12 + 12)],
);
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, (&M[0..144]).to_vec());
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, &M[0..144]);
}
Self::pow7_triple(&mut st0, &mut st1, &mut st2);
Self::add_avx(&mut st0, &mut st1, &mut st2, (&C[48..60]).to_vec());
Self::mmult_avx(&mut st0, &mut st1, &mut st2, (&P[0..144]).to_vec());
Self::add_avx(&mut st0, &mut st1, &mut st2, &C[48..60]);
Self::mmult_avx(&mut st0, &mut st1, &mut st2, &P[0..144]);

for r in 0..n_rounds_p {
let st0_slice = st0.as_slice_mut();
Expand All @@ -382,13 +382,7 @@ impl Poseidon {
st0_slice[0] = _st0.as_slice_mut()[0];

let mut tmp = Avx2GoldilocksField::ZEROS;
spmv_avx_4x12(
&mut tmp,
st0,
st1,
st2,
S[12 * 2 * r..(12 * 2 * r + 12)].to_vec(),
);
spmv_avx_4x12(&mut tmp, st0, st1, st2, &S[12 * 2 * r..(12 * 2 * r + 12)]);
let tmp_slice = tmp.as_slice_mut();
let sum = FGL::from_repr(tmp_slice[0]).unwrap()
+ FGL::from_repr(tmp_slice[1]).unwrap()
Expand All @@ -409,7 +403,7 @@ impl Poseidon {
&mut st1,
&mut st2,
*s0,
(&S[(12 * (2 * r + 1))..(12 * (2 * r + 2))]).to_vec(),
&S[(12 * (2 * r + 1))..(12 * (2 * r + 2))],
);

let st0_slice = st0.as_slice_mut();
Expand All @@ -422,14 +416,13 @@ impl Poseidon {
&mut st0,
&mut st1,
&mut st2,
(&C[((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t)
..((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + 12)])
.to_vec(),
&C[((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t)
..((n_rounds_f / 2 + 1) * t + n_rounds_p + r * t + 12)],
);
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, (&M[0..144]).to_vec());
Self::mmult_avx_8(&mut st0, &mut st1, &mut st2, &M[0..144]);
}
Self::pow7_triple(&mut st0, &mut st1, &mut st2);
Self::mmult_avx(&mut st0, &mut st1, &mut st2, (&M[0..144]).to_vec());
Self::mmult_avx(&mut st0, &mut st1, &mut st2, &M[0..144]);

let st0_slice = st0.as_slice();

Expand All @@ -447,13 +440,19 @@ mod tests {
use algebraic::packed::PackedField;
use plonky::field_gl::Fr as FGL;
use plonky::PrimeField;
use std::time::{Duration, Instant};

#[test]
fn test_poseidon_opt_hash_all_0_avx() {
let poseidon = Poseidon::new();
let input = vec![FGL::ZERO; 8];
let state = vec![FGL::ZERO; 4];

let start = Instant::now();
let res = poseidon.hash(&input, &state, 4).unwrap();
let hash_avx2_duration = start.elapsed();
println!("hash_avx2_duration_0: {:?}", hash_avx2_duration);

let expected = vec![
FGL::from(0x3c18a9786cb0b359u64),
FGL::from(0xc4055e3364a246c3u64),
Expand All @@ -468,7 +467,11 @@ mod tests {
let poseidon = Poseidon::new();
let input = (0u64..8).map(FGL::from).collect::<Vec<FGL>>();
let state = (8u64..12).map(FGL::from).collect::<Vec<FGL>>();
let start = Instant::now();
let res = poseidon.hash(&input, &state, 4).unwrap();
let hash_avx2_duration = start.elapsed();
println!("hash_avx2_duration_1: {:?}", hash_avx2_duration);

let expected = vec![
FGL::from(0xd64e1e3efc5b8e9eu64),
FGL::from(0x53666633020aaa47u64),
Expand All @@ -484,7 +487,11 @@ mod tests {
let init = FGL::ZERO - FGL::ONE;
let input = vec![init; 8];
let state = vec![init; 4];
let start = Instant::now();
let res = poseidon.hash(&input, &state, 4).unwrap();
let hash_avx2_duration = start.elapsed();
println!("hash_avx2_duration_2: {:?}", hash_avx2_duration);

let expected = vec![
FGL::from(0xbe0085cfc57a8357u64),
FGL::from(0xd95af71847d05c09u64),
Expand All @@ -494,6 +501,45 @@ mod tests {
assert_eq!(res, expected);
}

/// Timing run: hashes the input `0..8` with state `8..12` one hundred times
/// and prints the mean wall-clock duration of a single `hash` call.
/// The hash output itself is not checked here.
#[test]
fn test_poseidon_opt_hash_1_11_avx_average() {
    let hasher = Poseidon::new();
    let input: Vec<FGL> = (0u64..8).map(FGL::from).collect();
    let state: Vec<FGL> = (8u64..12).map(FGL::from).collect();

    let iterations: u32 = 100;

    // Time each call on its own and add up the per-call durations.
    let total_duration: Duration = (0..iterations)
        .map(|_| {
            let timer = Instant::now();
            let _res = hasher.hash(&input, &state, 4).unwrap();
            timer.elapsed()
        })
        .sum();

    println!(
        "Average hash_avx2_duration_1: {:?}",
        total_duration / iterations
    );
}

/// Timing run: hashes eight input elements and four state elements, all equal
/// to `FGL::ZERO - FGL::ONE`, one hundred times and prints the mean wall-clock
/// duration of a single `hash` call. The hash output itself is not checked here.
#[test]
fn test_poseidon_opt_hash_all_neg_1_avx_average() {
    let hasher = Poseidon::new();
    let minus_one = FGL::ZERO - FGL::ONE;
    let input = vec![minus_one; 8];
    let state = vec![minus_one; 4];

    let iterations: u32 = 100;

    // Accumulate the elapsed time of each individual call.
    let total_duration = (0..iterations).fold(Duration::new(0, 0), |elapsed_so_far, _| {
        let timer = Instant::now();
        let _res = hasher.hash(&input, &state, 4).unwrap();
        elapsed_so_far + timer.elapsed()
    });

    let average_duration = total_duration / iterations;
    println!("Average hash_avx2_duration_2: {:?}", average_duration);
}

#[test]
fn test_spmv_avx_4x12() {
let mut out = Avx2GoldilocksField::ZEROS;
Expand All @@ -516,9 +562,9 @@ mod tests {
FrRepr([18446744069414584320]),
]);

let in12 = vec![FrRepr([18446744069414584320]); 12];
let in12 = [FrRepr([18446744069414584320]); 12];
unsafe {
spmv_avx_4x12(&mut out, *in0, *in1, *in2, in12);
spmv_avx_4x12(&mut out, *in0, *in1, *in2, &in12);
};
let tmp_slice = out.as_slice_mut();
let _sum = FGL::from_repr(tmp_slice[0]).unwrap()
Expand Down

0 comments on commit d86ee32

Please sign in to comment.