Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
andreistefanescu committed Dec 3, 2021
1 parent 10e6ee6 commit 715041c
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 65 deletions.
12 changes: 6 additions & 6 deletions spec/implementation/Types.cry
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ module implementation::Types where

type Vec256 = [4][64]
type Vec384 = [384]
//type Vec384 = [6][64]
type Vec384_x86 = [6][64]
type Vec512 = [8][64]
type Vec768 = [768]
//type Vec768 = [12][64]
type Vec768_x86 = [12][64]
type Pow256 = [32][8]
type Limb = [64]
type Size_t = [64]
Expand All @@ -32,11 +32,11 @@ vec_abs limbs = join (reverse limbs)
vec256_abs = vec_abs`{256/64}
vec384_abs : Vec384 -> Vec384
vec384_abs x = x
//vec384_abs = vec_abs`{384/64}
vec384_x86_abs = vec_abs`{384/64}
vec512_abs = vec_abs`{512/64}
vec768_abs : Vec768 -> Vec768
vec768_abs x = x
//vec768_abs = vec_abs`{768/64}
vec768_x86_abs = vec_abs`{768/64}

/**
* A vector representing a given (bitvector) integer. The integer should be non-negative.
Expand All @@ -49,10 +49,10 @@ vec_rep x = reverse (split x)
vec256_rep = vec_rep`{256/64}
vec384_rep : Vec384 -> Vec384
vec384_rep x = x
//vec384_rep = vec_rep`{384/64}
vec384_x86_rep = vec_rep`{384/64}
vec768_rep : Vec768 -> Vec768
vec768_rep x = x
//vec768_rep = vec_rep`{768/64}
vec768_x86_rep = vec_rep`{768/64}

/**
* Represent a Boolean as a limb, with value 1 for true, 0 for false
Expand Down
102 changes: 86 additions & 16 deletions spec/implementation/x86.cry
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,56 @@ mulx_mont_384x
-> Vec384 // m
-> [64] // n0
-> Vec768 // result
mulx_mont_384x a b m n0 = result
mulx_mont_384x a b m n0 = vec768_x86_abs (mulx_mont_384x_impl (vec768_x86_rep a) (vec768_x86_rep b) (vec384_x86_rep m) n0)
mulx_mont_384x_impl
: Vec768_x86 // a
-> Vec768_x86 // b
-> Vec384_x86 // m
-> [64] // n0
-> Vec768_x86 // result
mulx_mont_384x_impl a b m n0 = result
where
t0 = __mulx_384 (take a) (take b)
t1 = __mulx_384 (drop a) (drop b)
t2 = __mulx_384 (__add_mod_384 (take b) (drop b) m) (__add_mod_384 (take a) (drop a) m)
t2' = __sub_mod_384x384 (__sub_mod_384x384 t2 t0 m) t1 m
t0' = __sub_mod_384x384 t0 t1 m
ret_re = redcx_mont_384 t0' m n0
ret_im = redcx_mont_384 t2' m n0
ret_re = redcx_mont_384_impl t0' m n0
ret_im = redcx_mont_384_impl t2' m n0
result = ret_re # ret_im

sqrx_mont_384x
: Vec768 // a
-> Vec384 // m
-> [64] // n0
-> Vec768 // result
sqrx_mont_384x a m n0 = result
sqrx_mont_384x a m n0 = vec768_x86_abs (sqrx_mont_384x_impl (vec768_x86_rep a) (vec384_x86_rep m) n0)
sqrx_mont_384x_impl
: Vec768_x86 // a
-> Vec384_x86 // m
-> [64] // n0
-> Vec768_x86 // result
sqrx_mont_384x_impl a m n0 = result
where
t0 = __add_mod_384 (take a) (drop a) m
t1 = __sub_mod_384 (take a) (drop a) m
tmp = mulx_mont_384 (take a) (drop a) m n0
tmp = mulx_mont_384_impl (take a) (drop a) m n0
ret_im = __add_mod_384 tmp tmp m
ret_re = mulx_mont_384 t0 t1 m n0
ret_re = mulx_mont_384_impl t0 t1 m n0
result = ret_re # ret_im

mulx_382x
: Vec768 // a
-> Vec768 // b
-> Vec384 // m
-> [2]Vec768 // result
mulx_382x a b m = result
mulx_382x a b m = map vec768_x86_abs (mulx_382x_impl (vec768_x86_rep a) (vec768_x86_rep b) (vec384_x86_rep m))
mulx_382x_impl
: Vec768_x86 // a
-> Vec768_x86 // b
-> Vec384_x86 // m
-> [2]Vec768_x86 // result
mulx_382x_impl a b m = result
where
a_re = take a
a_im = drop a
Expand Down Expand Up @@ -70,7 +89,12 @@ sqrx_382x
: Vec768 // a
-> Vec384 // m
-> [2]Vec768 // result
sqrx_382x a m = result
sqrx_382x a m = map vec768_x86_abs (sqrx_382x_impl (vec768_x86_rep a) (vec384_x86_rep m))
sqrx_382x_impl
: Vec768_x86 // a
-> Vec384_x86 // m
-> [2]Vec768_x86 // result
sqrx_382x_impl a m = result
where
a_re = take a
a_im = drop a
Expand Down Expand Up @@ -103,16 +127,29 @@ redcx_mont_384
-> Vec384 // m; pointer in rdx
-> [64] // n0; passed in rcx
-> Vec384 // result
redcx_mont_384 a m n0 = result
redcx_mont_384 a m n0 = vec384_x86_abs (redcx_mont_384_impl (vec768_x86_rep a) (vec384_x86_rep m) n0)
redcx_mont_384_impl
: Vec768_x86 // a; pointer in rsi
-> Vec384_x86 // m; pointer in rdx
-> [64] // n0; passed in rcx
-> Vec384_x86 // result
redcx_mont_384_impl a m n0 = result
where
acc = __mulx_by_1_mont_384 (take a) m n0
result = __redc_tail_mont_384 acc a m

fromx_mont_384
: Vec384 // a; pointer in rsi
-> Vec384 // m; pointer in rdx
-> [64] // n0; passed in rcx
-> Vec384 // result
fromx_mont_384 a m n0 = result
fromx_mont_384 a m n0 = vec384_x86_abs (fromx_mont_384_impl (vec384_x86_rep a) (vec384_x86_rep m) n0)
fromx_mont_384_impl
: Vec384_x86 // a; pointer in rsi
-> Vec384_x86 // m; pointer in rdx
-> [64] // n0; passed in rcx
-> Vec384_x86 // result
fromx_mont_384_impl a m n0 = result
where

[r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0, r15_0] = __mulx_by_1_mont_384 a m n0
Expand Down Expand Up @@ -141,7 +178,13 @@ sgn0x_pty_mont_384
-> Vec384 // p; passed in rsi
-> [64] // n0; passed in rdx
-> [64]
sgn0x_pty_mont_384 a p n0 = result
sgn0x_pty_mont_384 a p n0 = sgn0x_pty_mont_384_impl (vec384_x86_rep a) (vec384_x86_rep p) n0
sgn0x_pty_mont_384_impl
: Vec384_x86 // a; passed in rdi
-> Vec384_x86 // p; passed in rsi
-> [64] // n0; passed in rdx
-> [64]
sgn0x_pty_mont_384_impl a p n0 = result
where

[r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0, r15_0] = __mulx_by_1_mont_384 a p n0
Expand Down Expand Up @@ -171,7 +214,13 @@ sgn0x_pty_mont_384x
-> Vec384 // m; passed in rsi
-> [64] // n0; passed in rdx
-> [64]
sgn0x_pty_mont_384x a m n0 = result
sgn0x_pty_mont_384x a m n0 = sgn0x_pty_mont_384x_impl (vec768_x86_rep a) (vec384_x86_rep m) n0
sgn0x_pty_mont_384x_impl
: Vec768_x86 // a; passed in rdi
-> Vec384_x86 // m; passed in rsi
-> [64] // n0; passed in rdx
-> [64]
sgn0x_pty_mont_384x_impl a m n0 = result
where
[r8_4, r9_4, r10_4, r11_4, r12_4, r13_4, r14_4, r15_4] = __mulx_by_1_mont_384 (drop a) m n0
r12_5 = r14_4
Expand Down Expand Up @@ -241,7 +290,14 @@ mulx_mont_384
-> Vec384 // m; passed in rcx
-> [64] // n0; passed in r8
-> Vec384 // result
mulx_mont_384 a b m n0 = result
mulx_mont_384 a b m n0 = vec384_x86_abs (mulx_mont_384_impl (vec384_x86_rep a) (vec384_x86_rep b) (vec384_x86_rep m) n0)
mulx_mont_384_impl
: Vec384_x86 // a; passed in rsi
-> Vec384_x86 // b; passed in rdx
-> Vec384_x86 // m; passed in rcx
-> [64] // n0; passed in r8
-> Vec384_x86 // result
mulx_mont_384_impl a b m n0 = result
where
(r9, r8) = mulx (a @ 0) (b @ 0)
acc =
Expand All @@ -266,22 +322,36 @@ sqrx_n_mul_mont_383
-> [64] // n0
-> Vec384 // b
-> Vec384 // result
sqrx_n_mul_mont_383 a count m n0 b = result
sqrx_n_mul_mont_383 a count m n0 b = vec384_x86_abs (sqrx_n_mul_mont_383_impl (vec384_x86_rep a) count (vec384_x86_rep m) n0 (vec384_x86_rep b))
sqrx_n_mul_mont_383_impl
: Vec384_x86 // a
-> Integer // count
-> Vec384_x86 // m
-> [64] // n0
-> Vec384_x86 // b
-> Vec384_x86 // result
sqrx_n_mul_mont_383_impl a count m n0 b = result
where
mulx_mont_383_nonred x y = __mulx_mont_383_nonred [r8, r9, undefined, undefined, x @ 3, undefined, x @ 0, x @ 1, x @ 2] (y @ 0) n0 (x @ 4) (x @ 5) x y m
where (r9, r8) = mulx (x @ 0) (y @ 0)
loop counter acc =
if counter == 0
then acc
else loop (counter - 1) (mulx_mont_383_nonred acc acc)
result = mulx_mont_384 (loop count a) b m n0
result = mulx_mont_384_impl (loop count a) b m n0

sqrx_mont_382x
: Vec768 // a; pointer in rsi
-> Vec384 // m; pointer in rdx
-> [64] // n0; passed in rcx
-> Vec768 // result
sqrx_mont_382x a m n0 = result
sqrx_mont_382x a m n0 = vec768_x86_abs (sqrx_mont_382x_impl (vec768_x86_rep a) (vec384_x86_rep m) n0)
sqrx_mont_382x_impl
: Vec768_x86 // a; pointer in rsi
-> Vec384_x86 // m; pointer in rdx
-> [64] // n0; passed in rcx
-> Vec768_x86 // result
sqrx_mont_382x_impl a m n0 = result
where
mulx_mont_383_nonred x y = __mulx_mont_383_nonred [r8, r9, undefined, undefined, x @ 3, undefined, x @ 0, x @ 1, x @ 2] (y @ 0) n0 (x @ 4) (x @ 5) x y m
where (r9, r8) = mulx (x @ 0) (y @ 0)
Expand Down
56 changes: 28 additions & 28 deletions spec/implementation/x86/helpers.cry
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ shrd_ret ret x shiftin count = if count && 63 == 0 then ret else drop ((shiftin

// Helper routines
__sub_mod_384x384
: Vec768 // a; pointer in rsi
-> Vec768 // b; pointer in rdx
-> Vec384 // n; pointer in rcx
-> Vec768 // result; pointer in rdi
: Vec768_x86 // a; pointer in rsi
-> Vec768_x86 // b; pointer in rdx
-> Vec384_x86 // n; pointer in rcx
-> Vec768_x86 // result; pointer in rdi
__sub_mod_384x384 a b n = result
where
[r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0] = take a
Expand Down Expand Up @@ -169,10 +169,10 @@ __sub_mod_384x384 a b n = result
result = [result0, result1, result2, result3, result4, result5, result6, result7, result8, result9, result10, result11]

__add_mod_384
: Vec384 // a
-> Vec384 // b
-> Vec384 // n
-> Vec384 // result
: Vec384_x86 // a
-> Vec384_x86 // b
-> Vec384_x86 // n
-> Vec384_x86 // result
__add_mod_384 a b n = result
where
(cf_0, tmp_0) = add (a @ 0) (b @ 0)
Expand All @@ -197,10 +197,10 @@ __add_mod_384 a b n = result
result = [result0, result1, result2, result3, result4, result5]

__sub_mod_384
: Vec384 // a
-> Vec384 // b
-> Vec384 // n
-> Vec384 // result
: Vec384_x86 // a
-> Vec384_x86 // b
-> Vec384_x86 // n
-> Vec384_x86 // result
__sub_mod_384 a b n = result
where
(cf_0, tmp0) = sub (a @ 0) (b @ 0)
Expand All @@ -226,9 +226,9 @@ __sub_mod_384 a b n = result

__redc_tail_mont_384
: [8][64] // acc registers
-> Vec768 // a; stored in rsi
-> Vec384 // n; pointer in rbx
-> Vec384 // result; stored in rdi
-> Vec768_x86 // a; stored in rsi
-> Vec384_x86 // n; pointer in rbx
-> Vec384_x86 // result; stored in rdi
__redc_tail_mont_384 acc a n = result
where
[r8_0, r9_0, r10_0, r11_0, r12_0, r13_0, r14_0, r15_0] = acc
Expand Down Expand Up @@ -267,9 +267,9 @@ __redc_tail_mont_384 acc a n = result
result = [result0, result1, result2, result3, result4, result5]

__mulx_384
: Vec384 // a; pointer stored in rsi
-> Vec384 // b; pointer stored in rbx
-> Vec768 // result; stored in rdi?
: Vec384_x86 // a; pointer stored in rsi
-> Vec384_x86 // b; pointer stored in rbx
-> Vec768_x86 // result; stored in rdi?
__mulx_384 a b = [result0, result1, result2, result3, result4, result5] # resulthigh
where
rdx_0 = b @ 0
Expand Down Expand Up @@ -404,8 +404,8 @@ __mulx_384 a b = [result0, result1, result2, result3, result4, result5] # result
resulthigh = [r8_106, r9_109, r10_112, r11_115, r12_119, r13_121]

__mulx_by_1_mont_384
: Vec384 // a; pointer stored in rsi
-> Vec384 // b; pointer stored in rbx
: Vec384_x86 // a; pointer stored in rsi
-> Vec384_x86 // b; pointer stored in rbx
-> [64] // n0; stored in rcx
-> [8][64] // result
__mulx_by_1_mont_384 a b n0 = result
Expand Down Expand Up @@ -569,10 +569,10 @@ __mulx_mont_384
-> [64] // stack[1]
-> [64] // lo; stored in rdi
-> [64] // hi; stored in rbp
-> Vec384 // a; pointer in rsi
-> Vec384 // b; pointer in rbx
-> Vec384 // n; pointer in rcx
-> Vec384 // result
-> Vec384_x86 // a; pointer in rsi
-> Vec384_x86 // b; pointer in rbx
-> Vec384_x86 // n; pointer in rcx
-> Vec384_x86 // result
__mulx_mont_384 acc b0 stack1_0 lo hi a b n = result
where
[r8_0, r9_0, _, _, r12_0, _, r14_0, r15_0, rax_0] = acc
Expand Down Expand Up @@ -917,10 +917,10 @@ __mulx_mont_383_nonred
-> [64] // stack[1]
-> [64] // lo; stored in rdi
-> [64] // hi; stored in rbp
-> Vec384 // a; pointer in rsi
-> Vec384 // b; pointer in rbx
-> Vec384 // n; pointer in rcx
-> Vec384 // result
-> Vec384_x86 // a; pointer in rsi
-> Vec384_x86 // b; pointer in rbx
-> Vec384_x86 // n; pointer in rcx
-> Vec384_x86 // result
__mulx_mont_383_nonred acc b0 stack1_0 lo hi a b n = result
where
[r8_0, r9_0, _, _, r12_0, _, r14_0, r15_0, rax_0] = acc
Expand Down
Loading

0 comments on commit 715041c

Please sign in to comment.