Skip to content

Commit

Permalink
Fixup KEM combiner
Browse files Browse the repository at this point in the history
  • Loading branch information
david415 committed Sep 2, 2024
1 parent e6a4d3c commit 03448e6
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 45 deletions.
1 change: 1 addition & 0 deletions CryptWalker.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ import CryptWalker.protocol.merkle_tree
import CryptWalker.kem.kem
import CryptWalker.kem.adapter
import CryptWalker.kem.schemes
import CryptWalker.kem.combiner
import CryptWalker.sign.ed25519
77 changes: 34 additions & 43 deletions CryptWalker/kem/combiner.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ SPDX-License-Identifier: AGPL-3.0-only
-/

import Mathlib.Data.ByteArray

import CryptWalker.kem.kem

namespace CryptWalker.kem.combiner
Expand All @@ -17,7 +16,7 @@ SplitPRF can be used with any number of KEMs
and it implement split PRF KEM combiner as:
cct := cct1 || cct2 || cct3 || ...
return H(ss1 || cct) XOR H(ss2 || cct) XOR H(ss3 || cct)
return H(ss1 || cct) XOR H(ss2 || cct) XOR H(ss3 || cct)
in order to retain IND-CCA2 security
as described in KEM Combiners https://eprint.iacr.org/2018/024.pdf
Expand All @@ -28,7 +27,7 @@ def hashSize := 32

def xorByteArrays (a b : ByteArray) : ByteArray :=
if a.size ≠ b.size then
panic! "ByteArrays must be of equal size"
panic! "xorByteArrays: ByteArrays must be of equal size"
else
ByteArray.mk (Array.zipWith a.data b.data fun x y => x ^^^ y)

Expand All @@ -37,7 +36,7 @@ def splitPRF (hash : ByteArray → ByteArray) (ss : List ByteArray) (ct : List B
panic! "splitPRF failure: mismatched List lengths"
else
let bigCt : ByteArray := ct.foldl (fun acc blob => acc ++ blob) ByteArray.empty
(ss.map (fun x => hash (x ++ bigCt))).foldl (fun acc h => xorByteArrays acc h) (ByteArray.mkEmpty hashSize)
(ss.map (fun x => hash (x ++ bigCt))).foldl (fun acc h => xorByteArrays acc h) (ByteArray.mk (Array.mkArray hashSize 0))

structure PrivateKey where
data : List ByteArray
Expand Down Expand Up @@ -66,75 +65,67 @@ def splitByteArrayIntoChunks (bytes : ByteArray) (sizes : List Nat) : Option (Li
aux part2 sizesTail (part1 :: acc)
aux bytes sizes []

def createKEMCombiner (name : String) (hash : ByteArray → ByteArray) (KEMs : List KEM) : KEM :=
{
PublicKeyType := PublicKey,
PrivateKeyType := PrivateKey,
privateKeySize := KEMs.foldl (fun acc x => acc + x.privateKeySize) 0,
publicKeySize := KEMs.foldl (fun acc x => acc + x.publicKeySize) 0,
ciphertextSize := KEMs.foldl (fun acc x => acc + x.ciphertextSize) 0,
name := name,

structure Combiner where
hash : ByteArray → ByteArray
KEMs : List KEM

instance (combiner : Combiner) (name : String) : KEM where
PublicKeyType := PublicKey
PrivateKeyType := PrivateKey

privateKeySize := combiner.KEMs.foldl (fun acc x => acc + x.privateKeySize) 0
publicKeySize := combiner.KEMs.foldl (fun acc x => acc + x.publicKeySize) 0
ciphertextSize := combiner.KEMs.foldl (fun acc x => acc + x.ciphertextSize) 0

name : String := name

generateKeyPair : IO (PublicKey × PrivateKey) := do
generateKeyPair := do
let mut pubkeyData : List ByteArray := []
let mut privkeyData : List ByteArray := []
for kem in combiner.KEMs do
for kem in KEMs do
let (newpubkey, newprivkey) ← kem.generateKeyPair
pubkeyData := pubkeyData ++ [kem.encodePublicKey newpubkey]
privkeyData := privkeyData ++ [kem.encodePrivateKey newprivkey]
pure ({ data := pubkeyData }, { data := privkeyData })
pure ({ data := pubkeyData }, { data := privkeyData }),

encapsulate : PublicKey → IO (ByteArray × ByteArray) := fun pubkey => do
encapsulate := fun pubkey => do
let mut sharedSecrets : List ByteArray := []
let mut ciphertexts : List ByteArray := []
let mut ciphertext : ByteArray := ByteArray.mkEmpty 0
for (kem, pubKeyChunk) in combiner.KEMs.zip pubkey.data do
let mut ciphertext : ByteArray := ByteArray.empty
for (kem, pubKeyChunk) in KEMs.zip pubkey.data do
match kem.decodePublicKey pubKeyChunk with
| none => panic! "failed to decode pub key"
| some pubkey =>
let (ct, ss) ← kem.encapsulate pubkey
sharedSecrets := sharedSecrets ++ [ss]
ciphertexts := ciphertexts ++ [ct]
ciphertext := ciphertext ++ ct
pure (ciphertext, splitPRF combiner.hash sharedSecrets ciphertexts)
pure (ciphertext, splitPRF hash sharedSecrets ciphertexts),

decapsulate : PrivateKey → ByteArray → ByteArray := fun privkey ciphertext =>
let sizes : List Nat := combiner.KEMs.foldl (fun acc x => acc ++ [x.ciphertextSize]) []
decapsulate := fun privkey ciphertext =>
let sizes := KEMs.map (fun x => x.ciphertextSize)
match splitByteArrayIntoChunks ciphertext sizes with
| none => panic! "failed to parse ciphertext"
| some ciphertexts =>
let pairs := List.zip combiner.KEMs ciphertexts
let pairs3 := List.zip pairs privkey.data
let sharedSecrets : List ByteArray := pairs3.map (fun x =>
match x.fst.fst.decodePrivateKey x.fst.snd with
let sharedSecrets := KEMs.zip ciphertexts |>.zip privkey.data |>.map (fun ((kem, ct), privKeyChunk) =>
match kem.decodePrivateKey privKeyChunk with
| none => panic! "decode private key failure"
| some innerPrivkey =>
x.fst.fst.decapsulate innerPrivkey x.fst.snd
| some innerPrivkey => kem.decapsulate innerPrivkey ct
)
splitPRF combiner.hash sharedSecrets ciphertexts
splitPRF hash sharedSecrets ciphertexts

encodePrivateKey : PrivateKey → ByteArray := fun privkey =>
privkey.data.foldl (fun acc key => acc ++ key) ByteArray.empty
encodePrivateKey := fun privkey =>
privkey.data.foldl (fun acc key => acc ++ key) ByteArray.empty,

decodePrivateKey : ByteArray → Option PrivateKey := fun bytes =>
let sizes : List Nat := combiner.KEMs.map (fun kem => kem.privateKeySize)
decodePrivateKey := fun bytes =>
let sizes : List Nat := KEMs.map (fun kem => kem.privateKeySize)
match splitByteArrayIntoChunks bytes sizes with
| none => none
| some keys => some { data := keys }
| some keys => some { data := keys },

encodePublicKey : PublicKey → ByteArray := fun pubkey =>
pubkey.data.foldl (fun acc key => acc ++ key) ByteArray.empty
encodePublicKey := fun pubkey =>
pubkey.data.foldl (fun acc key => acc ++ key) ByteArray.empty,

decodePublicKey : ByteArray → Option PublicKey := fun bytes =>
let sizes : List Nat := combiner.KEMs.map (fun kem => kem.publicKeySize)
decodePublicKey := fun bytes =>
let sizes : List Nat := KEMs.map (fun kem => kem.publicKeySize)
match splitByteArrayIntoChunks bytes sizes with
| none => none
| some keys => some { data := keys }
}

end CryptWalker.kem.combiner
2 changes: 1 addition & 1 deletion CryptWalker/kem/kem.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Batteries.Classes.SatisfiesM

namespace CryptWalker.kem.kem

structure KEM where
class KEM where
PublicKeyType : Type
PrivateKeyType : Type

Expand Down
10 changes: 9 additions & 1 deletion CryptWalker/kem/schemes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,29 @@ import CryptWalker.nike.x41417
import CryptWalker.nike.nike
import CryptWalker.kem.kem
import CryptWalker.kem.adapter
import CryptWalker.kem.combiner
import CryptWalker.hash.Sha2

open CryptWalker.nike
open CryptWalker.nike.nike
open CryptWalker.kem.kem
open CryptWalker.kem.adapter
open CryptWalker.kem.combiner
open CryptWalker.hash.Sha2

namespace CryptWalker.kem.schemes

def kemX25519 := toKEM $ Adapter.mk Sha256.hash x25519.Scheme
def kemX448 := toKEM $ Adapter.mk Sha256.hash x448.Scheme
def kem41417 := toKEM $ Adapter.mk Sha256.hash x41417.Scheme
def combinedKEM := createKEMCombiner "CombinedKEM" Sha256.hash [kemX25519, kemX448, kem41417]

def Schemes : List KEM :=
[
toKEM $ Adapter.mk Sha256.hash x25519.Scheme,
toKEM $ Adapter.mk Sha256.hash x448.Scheme,
toKEM $ Adapter.mk Sha256.hash x41417.Scheme
toKEM $ Adapter.mk Sha256.hash x41417.Scheme,
combinedKEM
]

end CryptWalker.kem.schemes

0 comments on commit 03448e6

Please sign in to comment.