Skip to content

Commit

Permalink
fix: add more samples to test overflow (mul/square)
Browse files Browse the repository at this point in the history
  • Loading branch information
eigmax committed Nov 20, 2023
1 parent e5c380b commit c2af50b
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 42 deletions.
53 changes: 27 additions & 26 deletions algebraic/src/arch/x86_64/avx2_field_gl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,15 +516,15 @@ mod tests {

/// Returns the fixed left-hand operands used by the tests in this module
/// (for example `test_mul`).
///
/// 18446744069414584320 equals 2^64 - 2^32, i.e. one less than 2^64 - 2^32 + 1
/// (presumably the modulus of `GoldilocksField` — confirm), so it exercises the
/// overflow edge of the packed arithmetic.
// NOTE(review): five initialisers are listed for a four-element array. This looks
// like the removed first sample and its replacement both being shown by the diff
// view — confirm which one is live before compiling.
fn test_vals_a() -> [GoldilocksField; 4] {
[
GoldilocksField([14479013849828404771u64]),
GoldilocksField([18446744069414584320u64]),
GoldilocksField([9087029921428221768u64]),
GoldilocksField([2441288194761790662u64]),
GoldilocksField([5646033492608483824u64]),
]
}
fn test_vals_b() -> [GoldilocksField; 4] {
[
GoldilocksField([17891926589593242302u64]),
GoldilocksField([18446744069414584320u64]),
GoldilocksField([11009798273260028228u64]),
GoldilocksField([2028722748960791447u64]),
GoldilocksField([7929433601095175579u64]),
Expand All @@ -541,29 +541,30 @@ mod tests {
let packed_res = *packed_a + *packed_b;
let arr_res = packed_res.as_slice();
let avx2_duration = start.elapsed();
// println!("arr_res: {:?}", arr_res);
// log::debug!("arr_res: {:?}", arr_res);

let start = Instant::now();
let expected = a_arr
.iter()
.zip(b_arr)
.map(|(&a, b)| Fr::from_repr(a).unwrap() + Fr::from_repr(b).unwrap());
let expected_values: Vec<Fr> = expected.collect();
// println!("expected values: {:?}", expected_values);
log::debug!("expected values: {:?}", expected_values[0].as_int());
let non_accelerated_duration = start.elapsed();
for (exp, &res) in expected_values.iter().zip(arr_res) {
assert_eq!(res, exp.into_repr());
}

println!("test_add_AVX2_accelerated time: {:?}", avx2_duration);
println!(
log::debug!("test_add_AVX2_accelerated time: {:?}", avx2_duration);
log::debug!(
"test_add_Non_accelerated time: {:?}",
non_accelerated_duration
);
}

#[test]
fn test_mul() {
env_logger::try_init().unwrap_or_default();
let a_arr = test_vals_a();
let b_arr = test_vals_b();
let start = Instant::now();
Expand All @@ -572,7 +573,7 @@ mod tests {
let packed_res = packed_a * packed_b;
let arr_res = packed_res.as_slice();
let avx2_duration = start.elapsed();
// println!("arr_res: {:?}", arr_res);
// log::debug!("arr_res: {:?}", arr_res);

let start = Instant::now();
let expected = a_arr
Expand All @@ -581,14 +582,14 @@ mod tests {
.map(|(&a, b)| Fr::from_repr(a).unwrap() * Fr::from_repr(b).unwrap());
let expected_values: Vec<Fr> = expected.collect();
let non_accelerated_duration = start.elapsed();
// println!("expected values: {:?}", expected_values);
log::debug!("expected values: {:?}", expected_values);

for (exp, &res) in expected_values.iter().zip(arr_res) {
assert_eq!(res, exp.into_repr());
}

println!("test_mul_AVX2_accelerated time: {:?}", avx2_duration);
println!(
log::debug!("test_mul_AVX2_accelerated time: {:?}", avx2_duration);
log::debug!(
"test_mul_Non_accelerated time: {:?}",
non_accelerated_duration
);
Expand All @@ -602,7 +603,7 @@ mod tests {
let packed_res = packed_a / GoldilocksField([7929433601095175579u64]);
let arr_res = packed_res.as_slice();
let avx2_duration = start.elapsed();
// println!("arr_res: {:?}", arr_res);
// log::debug!("arr_res: {:?}", arr_res);

let start = Instant::now();
let expected = a_arr.iter().map(|&a| {
Expand All @@ -611,14 +612,14 @@ mod tests {
});
let expected_values: Vec<Fr> = expected.collect();
let non_accelerated_duration = start.elapsed();
// println!("expected values: {:?}", expected_values);
// log::debug!("expected values: {:?}", expected_values);

for (exp, &res) in expected_values.iter().zip(arr_res) {
assert_eq!(res, exp.into_repr());
}

println!("test_div_AVX2_accelerated time: {:?}", avx2_duration);
println!(
log::debug!("test_div_AVX2_accelerated time: {:?}", avx2_duration);
log::debug!(
"test_div_Non_accelerated time: {:?}",
non_accelerated_duration
);
Expand All @@ -632,7 +633,7 @@ mod tests {
let packed_res = packed_a.square();
let arr_res = packed_res.as_slice();
let avx2_duration = start.elapsed();
// println!("arr_res: {:?}", arr_res);
// log::debug!("arr_res: {:?}", arr_res);

let start = Instant::now();
let mut expected_values = Vec::new();
Expand All @@ -648,12 +649,12 @@ mod tests {
}
}
let non_accelerated_duration = start.elapsed();
// println!("expected values: {:?}", expected_values);
// log::debug!("expected values: {:?}", expected_values);
for (exp, &res) in expected_values.iter().zip(arr_res) {
assert_eq!(res, exp.into_repr());
}
println!("test_square_AVX2_accelerated time: {:?}", avx2_duration);
println!(
log::debug!("test_square_AVX2_accelerated time: {:?}", avx2_duration);
log::debug!(
"test_square_Non_accelerated time: {:?}",
non_accelerated_duration
);
Expand All @@ -667,20 +668,20 @@ mod tests {
let packed_res = -packed_a;
let arr_res = packed_res.as_slice();
let avx2_duration = start.elapsed();
// println!("arr_res: {:?}", arr_res);
// log::debug!("arr_res: {:?}", arr_res);

let start = Instant::now();
let expected = a_arr.iter().map(|&a| -Fr::from_repr(a).unwrap());
let expected_values: Vec<Fr> = expected.collect();
let non_accelerated_duration = start.elapsed();
// println!("expected values: {:?}", expected_values);
// log::debug!("expected values: {:?}", expected_values);

for (exp, &res) in expected_values.iter().zip(arr_res) {
assert_eq!(res, exp.into_repr());
}

println!("test_neg_AVX2_accelerated time: {:?}", avx2_duration);
println!(
log::debug!("test_neg_AVX2_accelerated time: {:?}", avx2_duration);
log::debug!(
"test_neg_Non_accelerated time: {:?}",
non_accelerated_duration
);
Expand All @@ -696,7 +697,7 @@ mod tests {
let packed_res = packed_a - packed_b;
let arr_res = packed_res.as_slice();
let avx2_duration = start.elapsed();
// println!("arr_res: {:?}", arr_res);
// log::debug!("arr_res: {:?}", arr_res);

let start = Instant::now();
let expected = a_arr
Expand All @@ -705,14 +706,14 @@ mod tests {
.map(|(&a, b)| Fr::from_repr(a).unwrap() - Fr::from_repr(b).unwrap());
let expected_values: Vec<Fr> = expected.collect();
let non_accelerated_duration = start.elapsed();
// println!("expected values: {:?}", expected_values);
// log::debug!("expected values: {:?}", expected_values);

for (exp, &res) in expected_values.iter().zip(arr_res) {
assert_eq!(res, exp.into_repr());
}

println!("test_sub_AVX2_accelerated time: {:?}", avx2_duration);
println!(
log::debug!("test_sub_AVX2_accelerated time: {:?}", avx2_duration);
log::debug!(
"test_sub_Non_accelerated time: {:?}",
non_accelerated_duration
);
Expand Down
67 changes: 51 additions & 16 deletions starky/src/arch/x86_64/avx2_poseidon_gl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ impl Default for Poseidon {
}
}


/// Writes `st0 * m[0] + st1 * m[1] + st2 * m[2]` into `r`, where `m[i]` are the
/// packed values obtained from `Avx2GoldilocksField::pack_slice(&m)`.
///
/// The previous contents of `r` are overwritten, not accumulated into.
///
/// # Panics
///
/// Panics on the index operation if packing `m` yields fewer than three packed
/// values.
///
/// # Safety
///
/// NOTE(review): the reason this function is `unsafe` is not visible here;
/// presumably the caller must guarantee AVX2 is available on the running CPU —
/// confirm against `Avx2GoldilocksField`.
#[inline]
unsafe fn spmv_avx_4x12(
    r: &mut Avx2GoldilocksField,
    st0: Avx2GoldilocksField,
    st1: Avx2GoldilocksField,
    st2: Avx2GoldilocksField,
    m: Vec<FrRepr>,
) {
    let packed_rows = Avx2GoldilocksField::pack_slice(&m);
    // Accumulate left to right, in the same order as the single-expression form.
    let first_two = (st0 * packed_rows[0]) + (st1 * packed_rows[1]);
    *r = first_two + (st2 * packed_rows[2]);
}

impl Poseidon {
pub fn new() -> Poseidon {
Self {}
Expand Down Expand Up @@ -159,10 +172,10 @@ impl Poseidon {
let mut r1 = Avx2GoldilocksField::ZEROS;
let mut r2 = Avx2GoldilocksField::ZEROS;
let mut r3 = Avx2GoldilocksField::ZEROS;
Self::spmv_avx_4x12(&mut r0, st0, st1, st2, m[0..12].to_vec());
Self::spmv_avx_4x12(&mut r1, st0, st1, st2, m[12..24].to_vec());
Self::spmv_avx_4x12(&mut r2, st0, st1, st2, m[24..36].to_vec());
Self::spmv_avx_4x12(&mut r3, st0, st1, st2, m[36..48].to_vec());
spmv_avx_4x12(&mut r0, st0, st1, st2, m[0..12].to_vec());
spmv_avx_4x12(&mut r1, st0, st1, st2, m[12..24].to_vec());
spmv_avx_4x12(&mut r2, st0, st1, st2, m[24..36].to_vec());
spmv_avx_4x12(&mut r3, st0, st1, st2, m[36..48].to_vec());
// Transpose: transform the 4x4 matrix stored in rows r0...r3 to the columns c0...c3
let t0 = _mm256_permute2f128_si256(r0.get(), r2.get(), 0b00100000);
let t1 = _mm256_permute2f128_si256(r1.get(), r3.get(), 0b00100000);
Expand All @@ -188,17 +201,6 @@ impl Poseidon {
*tmp = c0 + c1 + c2 + c3;
}

/// Writes `st0 * m[0] + st1 * m[1] + st2 * m[2]` into `r`, where `m[i]` are the
/// packed values obtained from `Avx2GoldilocksField::pack_slice(&m)`.
///
/// The previous contents of `r` are overwritten. Indexing panics if packing `m`
/// yields fewer than three packed values.
///
/// # Safety
///
/// NOTE(review): the reason this function is `unsafe` is not visible here;
/// presumably the caller must guarantee AVX2 support — confirm.
#[inline]
unsafe fn spmv_avx_4x12(
r: &mut Avx2GoldilocksField,
st0: Avx2GoldilocksField,
st1: Avx2GoldilocksField,
st2: Avx2GoldilocksField,
m: Vec<FrRepr>,
) {
// Shadow the raw representation vector with its packed form.
let m = Avx2GoldilocksField::pack_slice(&m);
*r = (st0 * m[0]) + (st1 * m[1]) + (st2 * m[2])
}

#[inline(always)]
unsafe fn mmult_avx_8(
Expand Down Expand Up @@ -382,7 +384,7 @@ impl Poseidon {
st0_slice[0] = _st0.as_slice_mut()[0];

let mut tmp = Avx2GoldilocksField::ZEROS;
Self::spmv_avx_4x12(
spmv_avx_4x12(
&mut tmp,
st0,
st1,
Expand Down Expand Up @@ -493,4 +495,37 @@ mod tests {
];
assert_eq!(res, expected);
}

/// Overflow check for `spmv_avx_4x12`: every input lane and every matrix entry is
/// 18446744069414584320 (2^64 - 2^32).
///
/// Assuming `FGL` is the field with modulus 2^64 - 2^32 + 1, that value is
/// congruent to -1, so each of the three lane-wise products is 1 and every output
/// lane must reduce to 3 (12 for the sum of the four lanes). The original test
/// computed the sum but never asserted on it, so a wrong result could not fail.
#[test]
fn test_spmv_avx_4x12() {
    let mut out = Avx2GoldilocksField::ZEROS;
    let in0 = Avx2GoldilocksField::from_slice(&[
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
    ]);
    let in1 = Avx2GoldilocksField::from_slice(&[
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
    ]);
    let in2 = Avx2GoldilocksField::from_slice(&[
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
        FrRepr([18446744069414584320]),
    ]);

    // Twelve entries: the routine indexes the first three packed values built from them.
    let in12 = vec![FrRepr([18446744069414584320]); 12];
    // SAFETY: NOTE(review) — presumably only requires AVX2 support, which this
    // avx2-specific module is expected to be compiled with; confirm.
    unsafe {
        spmv_avx_4x12(&mut out, *in0, *in1, *in2, in12);
    }

    let tmp_slice = out.as_slice_mut();
    let lanes = [tmp_slice[0], tmp_slice[1], tmp_slice[2], tmp_slice[3]];

    // (-1)·(-1) + (-1)·(-1) + (-1)·(-1) = 3 in every lane.
    for lane in lanes.iter() {
        assert!(FGL::from_repr(*lane).unwrap() == FGL::from_repr(FrRepr([3])).unwrap());
    }

    let sum = FGL::from_repr(lanes[0]).unwrap()
        + FGL::from_repr(lanes[1]).unwrap()
        + FGL::from_repr(lanes[2]).unwrap()
        + FGL::from_repr(lanes[3]).unwrap();
    assert!(sum == FGL::from_repr(FrRepr([12])).unwrap());
}
}

0 comments on commit c2af50b

Please sign in to comment.