From 715041c76cc8614ff3254d44e0db3b8726ae97a7 Mon Sep 17 00:00:00 2001 From: Andrei Stefanescu Date: Fri, 3 Dec 2021 02:14:17 +0000 Subject: [PATCH] wip --- spec/implementation/Types.cry | 12 ++-- spec/implementation/x86.cry | 102 +++++++++++++++++++++++----- spec/implementation/x86/helpers.cry | 56 +++++++-------- spec/implementation/x86/inverse.cry | 30 ++++---- 4 files changed, 135 insertions(+), 65 deletions(-) diff --git a/spec/implementation/Types.cry b/spec/implementation/Types.cry index 0ce112ac..398a9736 100644 --- a/spec/implementation/Types.cry +++ b/spec/implementation/Types.cry @@ -11,10 +11,10 @@ module implementation::Types where type Vec256 = [4][64] type Vec384 = [384] -//type Vec384 = [6][64] +type Vec384_x86 = [6][64] type Vec512 = [8][64] type Vec768 = [768] -//type Vec768 = [12][64] +type Vec768_x86 = [12][64] type Pow256 = [32][8] type Limb = [64] type Size_t = [64] @@ -32,11 +32,11 @@ vec_abs limbs = join (reverse limbs) vec256_abs = vec_abs`{256/64} vec384_abs : Vec384 -> Vec384 vec384_abs x = x -//vec384_abs = vec_abs`{384/64} +vec384_x86_abs = vec_abs`{384/64} vec512_abs = vec_abs`{512/64} vec768_abs : Vec768 -> Vec768 vec768_abs x = x -//vec768_abs = vec_abs`{768/64} +vec768_x86_abs = vec_abs`{768/64} /** * A vector representing a given (bitvector) integer. The integer should be non-negative. @@ -49,10 +49,10 @@ vec_rep x = reverse (split x) vec256_rep = vec_rep`{256/64} vec384_rep : Vec384 -> Vec384 vec384_rep x = x -//vec384_rep = vec_rep`{384/64} +vec384_x86_rep = vec_rep`{384/64} vec768_rep : Vec768 -> Vec768 vec768_rep x = x -//vec768_rep = vec_rep`{768/64} +vec768_x86_rep = vec_rep`{768/64} /** * Represent a Boolean as a limb, with value 1 for true, 0 for false diff --git a/spec/implementation/x86.cry b/spec/implementation/x86.cry index f00f1f31..f24d318e 100644 --- a/spec/implementation/x86.cry +++ b/spec/implementation/x86.cry @@ -10,15 +10,22 @@ mulx_mont_384x -> Vec384 // m -> [64] // n0 -> Vec768 // result -mulx_mont_384x a b m n0 = result +mulx_mont_384x a b m n0 = vec768_x86_abs (mulx_mont_384x_impl (vec768_x86_rep a) (vec768_x86_rep b) (vec384_x86_rep m) n0) +mulx_mont_384x_impl + : Vec768_x86 // a + -> Vec768_x86 // b + -> Vec384_x86 // m + -> [64] // n0 + -> Vec768_x86 // result +mulx_mont_384x_impl a b m n0 = result where t0 = __mulx_384 (take a) (take b) t1 = __mulx_384 (drop a) (drop b) t2 = __mulx_384 (__add_mod_384 (take b) (drop b) m) (__add_mod_384 (take a) (drop a) m) t2' = __sub_mod_384x384 (__sub_mod_384x384 t2 t0 m) t1 m t0' = __sub_mod_384x384 t0 t1 m - ret_re = redcx_mont_384 t0' m n0 - ret_im = redcx_mont_384 t2' m n0 + ret_re = redcx_mont_384_impl t0' m n0 + ret_im = redcx_mont_384_impl t2' m n0 result = ret_re # ret_im sqrx_mont_384x @@ -26,13 +33,19 @@ sqrx_mont_384x -> Vec384 // m -> [64] // n0 -> Vec768 // result -sqrx_mont_384x a m n0 = result +sqrx_mont_384x a m n0 = vec768_x86_abs (sqrx_mont_384x_impl (vec768_x86_rep a) (vec384_x86_rep m) n0) +sqrx_mont_384x_impl + : Vec768_x86 // a + -> Vec384_x86 // m + -> [64] // n0 + -> Vec768_x86 // result +sqrx_mont_384x_impl a m n0 = result where t0 = __add_mod_384 (take a) (drop a) m t1 = __sub_mod_384 (take a) (drop a) m - tmp = mulx_mont_384 (take a) (drop a) m n0 + tmp = mulx_mont_384_impl (take a) (drop a) m n0 ret_im = __add_mod_384 tmp tmp m - ret_re = mulx_mont_384 t0 t1 m n0 + ret_re = mulx_mont_384_impl t0 t1 m n0 result = ret_re # ret_im mulx_382x @@ -40,7 +53,13 @@ mulx_382x -> Vec768 // b -> Vec384 // m -> [2]Vec768 // result -mulx_382x a b m = result +mulx_382x a b m = 
map vec768_x86_abs (mulx_382x_impl (vec768_x86_rep a) (vec768_x86_rep b) (vec384_x86_rep m)) +mulx_382x_impl + : Vec768_x86 // a + -> Vec768_x86 // b + -> Vec384_x86 // m + -> [2]Vec768_x86 // result +mulx_382x_impl a b m = result where a_re = take a a_im = drop a @@ -70,7 +89,12 @@ sqrx_382x : Vec768 // a -> Vec384 // m -> [2]Vec768 // result -sqrx_382x a m = result +sqrx_382x a m = map vec768_x86_abs (sqrx_382x_impl (vec768_x86_rep a) (vec384_x86_rep m)) +sqrx_382x_impl + : Vec768_x86 // a + -> Vec384_x86 // m + -> [2]Vec768_x86 // result +sqrx_382x_impl a m = result where a_re = take a a_im = drop a @@ -103,16 +127,29 @@ redcx_mont_384 -> Vec384 // m; pointer in rdx -> [64] // n0; passed in rcx -> Vec384 // result -redcx_mont_384 a m n0 = result +redcx_mont_384 a m n0 = vec384_x86_abs (redcx_mont_384_impl (vec768_x86_rep a) (vec384_x86_rep m) n0) +redcx_mont_384_impl + : Vec768_x86 // a; pointer in rsi + -> Vec384_x86 // m; pointer in rdx + -> [64] // n0; passed in rcx + -> Vec384_x86 // result +redcx_mont_384_impl a m n0 = result where acc = __mulx_by_1_mont_384 (take a) m n0 result = __redc_tail_mont_384 acc a m + fromx_mont_384 : Vec384 // a; pointer in rsi -> Vec384 // m; pointer in rdx -> [64] // n0; passed in rcx -> Vec384 // result -fromx_mont_384 a m n0 = result +fromx_mont_384 a m n0 = vec384_x86_abs (fromx_mont_384_impl (vec384_x86_rep a) (vec384_x86_rep m) n0) +fromx_mont_384_impl + : Vec384_x86 // a; pointer in rsi + -> Vec384_x86 // m; pointer in rdx + -> [64] // n0; passed in rcx + -> Vec384_x86 // result +fromx_mont_384_impl a m n0 = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0, r15_0] = __mulx_by_1_mont_384 a m n0 @@ -141,7 +178,13 @@ sgn0x_pty_mont_384 -> Vec384 // p; passed in rsi -> [64] // n0; passed in rdx -> [64] -sgn0x_pty_mont_384 a p n0 = result +sgn0x_pty_mont_384 a p n0 = sgn0x_pty_mont_384_impl (vec384_x86_rep a) (vec384_x86_rep p) n0 +sgn0x_pty_mont_384_impl + : Vec384_x86 // a; passed in rdi + -> Vec384_x86 // p; passed in rsi + -> [64] // n0; passed in rdx + -> [64] +sgn0x_pty_mont_384_impl a p n0 = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0, r15_0] = __mulx_by_1_mont_384 a p n0 @@ -171,7 +214,13 @@ sgn0x_pty_mont_384x -> Vec384 // m; passed in rsi -> [64] // n0; passed in rdx -> [64] -sgn0x_pty_mont_384x a m n0 = result +sgn0x_pty_mont_384x a m n0 = sgn0x_pty_mont_384x_impl (vec768_x86_rep a) (vec384_x86_rep m) n0 +sgn0x_pty_mont_384x_impl + : Vec768_x86 // a; passed in rdi + -> Vec384_x86 // m; passed in rsi + -> [64] // n0; passed in rdx + -> [64] +sgn0x_pty_mont_384x_impl a m n0 = result where [r8_4, r9_4, r10_4, r11_4, r12_4, r13_4, r14_4, r15_4] = __mulx_by_1_mont_384 (drop a) m n0 r12_5 = r14_4 @@ -241,7 +290,14 @@ mulx_mont_384 -> Vec384 // m; passed in rcx -> [64] // n0; passed in r8 -> Vec384 // result -mulx_mont_384 a b m n0 = result +mulx_mont_384 a b m n0 = vec384_x86_abs (mulx_mont_384_impl (vec384_x86_rep a) (vec384_x86_rep b) (vec384_x86_rep m) n0) +mulx_mont_384_impl + : Vec384_x86 // a; passed in rsi + -> Vec384_x86 // b; passed in rdx + -> Vec384_x86 // m; passed in rcx + -> [64] // n0; passed in r8 + -> Vec384_x86 // result +mulx_mont_384_impl a b m n0 = result where (r9, r8) = mulx (a @ 0) (b @ 0) acc = @@ -266,7 +322,15 @@ sqrx_n_mul_mont_383 -> [64] // n0 -> Vec384 // b -> Vec384 // result -sqrx_n_mul_mont_383 a count m n0 b = result +sqrx_n_mul_mont_383 a count m n0 b = vec384_x86_abs (sqrx_n_mul_mont_383_impl (vec384_x86_rep a) count (vec384_x86_rep m) n0 (vec384_x86_rep b)) +sqrx_n_mul_mont_383_impl 
+ : Vec384_x86 // a + -> Integer // count + -> Vec384_x86 // m + -> [64] // n0 + -> Vec384_x86 // b + -> Vec384_x86 // result +sqrx_n_mul_mont_383_impl a count m n0 b = result where mulx_mont_383_nonred x y = __mulx_mont_383_nonred [r8, r9, undefined, undefined, x @ 3, undefined, x @ 0, x @ 1, x @ 2] (y @ 0) n0 (x @ 4) (x @ 5) x y m where (r9, r8) = mulx (x @ 0) (y @ 0) @@ -274,14 +338,20 @@ sqrx_n_mul_mont_383 a count m n0 b = result if counter == 0 then acc else loop (counter - 1) (mulx_mont_383_nonred acc acc) - result = mulx_mont_384 (loop count a) b m n0 + result = mulx_mont_384_impl (loop count a) b m n0 sqrx_mont_382x : Vec768 // a; pointer in rsi -> Vec384 // m; pointer in rdx -> [64] // n0; passed in rcx -> Vec768 // result -sqrx_mont_382x a m n0 = result +sqrx_mont_382x a m n0 = vec768_x86_abs (sqrx_mont_382x_impl (vec768_x86_rep a) (vec384_x86_rep m) n0) +sqrx_mont_382x_impl + : Vec768_x86 // a; pointer in rsi + -> Vec384_x86 // m; pointer in rdx + -> [64] // n0; passed in rcx + -> Vec768_x86 // result +sqrx_mont_382x_impl a m n0 = result where mulx_mont_383_nonred x y = __mulx_mont_383_nonred [r8, r9, undefined, undefined, x @ 3, undefined, x @ 0, x @ 1, x @ 2] (y @ 0) n0 (x @ 4) (x @ 5) x y m where (r9, r8) = mulx (x @ 0) (y @ 0) diff --git a/spec/implementation/x86/helpers.cry b/spec/implementation/x86/helpers.cry index 92ecd4da..19bc2a8e 100644 --- a/spec/implementation/x86/helpers.cry +++ b/spec/implementation/x86/helpers.cry @@ -111,10 +111,10 @@ shrd_ret ret x shiftin count = if count && 63 == 0 then ret else drop ((shiftin // Helper routines __sub_mod_384x384 - : Vec768 // a; pointer in rsi - -> Vec768 // b; pointer in rdx - -> Vec384 // n; pointer in rcx - -> Vec768 // result; pointer in rdi + : Vec768_x86 // a; pointer in rsi + -> Vec768_x86 // b; pointer in rdx + -> Vec384_x86 // n; pointer in rcx + -> Vec768_x86 // result; pointer in rdi __sub_mod_384x384 a b n = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0] = take a @@ -169,10 +169,10 @@ __sub_mod_384x384 a b n = result result = [result0, result1, result2, result3, result4, result5, result6, result7, result8, result9, result10, result11] __add_mod_384 - : Vec384 // a - -> Vec384 // b - -> Vec384 // n - -> Vec384 // result + : Vec384_x86 // a + -> Vec384_x86 // b + -> Vec384_x86 // n + -> Vec384_x86 // result __add_mod_384 a b n = result where (cf_0, tmp_0) = add (a @ 0) (b @ 0) @@ -197,10 +197,10 @@ __add_mod_384 a b n = result result = [result0, result1, result2, result3, result4, result5] __sub_mod_384 - : Vec384 // a - -> Vec384 // b - -> Vec384 // n - -> Vec384 // result + : Vec384_x86 // a + -> Vec384_x86 // b + -> Vec384_x86 // n + -> Vec384_x86 // result __sub_mod_384 a b n = result where (cf_0, tmp0) = sub (a @ 0) (b @ 0) @@ -226,9 +226,9 @@ __sub_mod_384 a b n = result __redc_tail_mont_384 : [8][64] // acc registers - -> Vec768 // a; stored in rsi - -> Vec384 // n; pointer in rbx - -> Vec384 // result; stored in rdi + -> Vec768_x86 // a; stored in rsi + -> Vec384_x86 // n; pointer in rbx + -> Vec384_x86 // result; stored in rdi __redc_tail_mont_384 acc a n = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0, r15_0] = acc @@ -267,9 +267,9 @@ __redc_tail_mont_384 acc a n = result result = [result0, result1, result2, result3, result4, result5] __mulx_384 - : Vec384 // a; pointer stored in rsi - -> Vec384 // b; pointer stored in rbx - -> Vec768 // result; stored in rdi? 
+ : Vec384_x86 // a; pointer stored in rsi + -> Vec384_x86 // b; pointer stored in rbx + -> Vec768_x86 // result; stored in rdi? __mulx_384 a b = [result0, result1, result2, result3, result4, result5] # resulthigh where rdx_0 = b @ 0 @@ -404,8 +404,8 @@ __mulx_384 a b = [result0, result1, result2, result3, result4, result5] # result resulthigh = [r8_106, r9_109, r10_112, r11_115, r12_119, r13_121] __mulx_by_1_mont_384 - : Vec384 // a; pointer stored in rsi - -> Vec384 // b; pointer stored in rbx + : Vec384_x86 // a; pointer stored in rsi + -> Vec384_x86 // b; pointer stored in rbx -> [64] // n0; stored in rcx -> [8][64] // result __mulx_by_1_mont_384 a b n0 = result @@ -569,10 +569,10 @@ __mulx_mont_384 -> [64] // stack[1] -> [64] // lo; stored in rdi -> [64] // hi; stored in rbp - -> Vec384 // a; pointer in rsi - -> Vec384 // b; pointer in rbx - -> Vec384 // n; pointer in rcx - -> Vec384 // result + -> Vec384_x86 // a; pointer in rsi + -> Vec384_x86 // b; pointer in rbx + -> Vec384_x86 // n; pointer in rcx + -> Vec384_x86 // result __mulx_mont_384 acc b0 stack1_0 lo hi a b n = result where [r8_0, r9_0, _, _, r12_0, _, r14_0, r15_0, rax_0] = acc @@ -917,10 +917,10 @@ __mulx_mont_383_nonred -> [64] // stack[1] -> [64] // lo; stored in rdi -> [64] // hi; stored in rbp - -> Vec384 // a; pointer in rsi - -> Vec384 // b; pointer in rbx - -> Vec384 // n; pointer in rcx - -> Vec384 // result + -> Vec384_x86 // a; pointer in rsi + -> Vec384_x86 // b; pointer in rbx + -> Vec384_x86 // n; pointer in rcx + -> Vec384_x86 // result __mulx_mont_383_nonred acc b0 stack1_0 lo hi a b n = result where [r8_0, r9_0, _, _, r12_0, _, r14_0, r15_0, rax_0] = acc diff --git a/spec/implementation/x86/inverse.cry b/spec/implementation/x86/inverse.cry index 5c0a549f..be4b6cc1 100644 --- a/spec/implementation/x86/inverse.cry +++ b/spec/implementation/x86/inverse.cry @@ -4,11 +4,11 @@ import implementation::Types import implementation::x86::helpers __smulx_383_n_shift_by_31 - : Vec384 // a, stored at low offsets of rsi - -> Vec384 // b, stored at high offsets of rsi + : Vec384_x86 // a, stored at low offsets of rsi + -> Vec384_x86 // b, stored at high offsets of rsi -> [64] // f0 -> [64] // g0 - -> (Vec384, [64], [64]) + -> (Vec384_x86, [64], [64]) __smulx_383_n_shift_by_31 a b f0 g0 = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0] = a @@ -214,11 +214,11 @@ __smulx_191_n_shift_by_31 a b f0 g0 = result result = ([result0, result1, result2], rdx_66, rcx_67) __smulx_383x63 - : Vec384 // u, stored at low offsets of rsi - -> Vec384 // v, stored at high offsets of rsi + : Vec384_x86 // u, stored at low offsets of rsi + -> Vec384_x86 // v, stored at high offsets of rsi -> [64] // f0 -> [64] // g0 - -> Vec384 + -> Vec384_x86 __smulx_383x63 u v f0 g0 = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0] = u @@ -310,11 +310,11 @@ __smulx_383x63 u v f0 g0 = result result = [result0, result1, result2, result3, result4, result5] __smulx_767x63 - : Vec384 // u, stored at low offsets of rsi - -> Vec768 // v, stored at high offsets of rsi + : Vec384_x86 // u, stored at low offsets of rsi + -> Vec768_x86 // v, stored at high offsets of rsi -> [64] // f0 -> [64] // g0 - -> Vec768 + -> Vec768_x86 __smulx_767x63 u v f0 g0 = result where [r8_0, r9_0, r10_0, r11_0, r12_0, r13_0] = u @@ -458,8 +458,8 @@ __smulx_767x63 u v f0 g0 = result __ab_approximation_31 : Integer // number of iterations - -> Vec384 // a - -> Vec384 // b + -> Vec384_x86 // a + -> Vec384_x86 // b -> [4][64] // result; f0, g0, f1, g1 __ab_approximation_31 iters a 
b = result where @@ -618,9 +618,9 @@ __inner_loop_62 iters a_init b_init = result loopresult = if curcounter == 0 then curresult else loop curcounter curresult ctx_inverse_mod_383 - : Vec384 // a; pointer in rsi - -> Vec384 // n; pointer in rdx - -> Vec768 // ret + : Vec384_x86 // a; pointer in rsi + -> Vec384_x86 // n; pointer in rdx + -> Vec768_x86 // ret ctx_inverse_mod_383 a n = result // (earlyresult # zero) // result where a_2 = a @@ -802,5 +802,5 @@ ctx_inverse_mod_383 a n = result // (earlyresult # zero) // result result9 = rbp_124 result10 = rcx_124 result11 = rdx_124 - result : Vec768 + result : Vec768_x86 result = resultlow # [result6, result7, result8, result9, result10, result11] \ No newline at end of file
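
The vec384_x86_rep / vec384_x86_abs (and vec768_x86_*) pairs introduced in
Types.cry convert between the flat bitvector types (Vec384, Vec768) and the
new x86 limb layouts (Vec384_x86 = [6][64], Vec768_x86 = [12][64]); the
wrapper functions in x86.cry rep the arguments, run the _impl body on the
limb view, and abs the result back. A minimal round-trip sanity check for
that conversion, assuming it is placed in a module that imports
implementation::Types (the property name below is illustrative, not part of
the patch):

    // vec_rep is reverse (split x) and vec_abs is join (reverse limbs),
    // so the two reverses cancel and the limb view loses no information
    vec384_x86_round_trip : Vec384 -> Bit
    property vec384_x86_round_trip x =
        vec384_x86_abs (vec384_x86_rep x) == x

Checking it with :prove vec384_x86_round_trip (or the analogous Vec768
version) should discharge directly, since both sides reduce to x.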