From 274d30c6e359f5cbc3487a75e4ddf06c5afae975 Mon Sep 17 00:00:00 2001 From: Basil Hess Date: Mon, 9 Dec 2024 17:46:47 +0100 Subject: [PATCH] Integrate ML-KEM from mlkem-native [full tests] Signed-off-by: Basil Hess --- .CMake/alg_support.cmake | 30 +- docs/algorithms/kem/ml_kem.md | 5 +- docs/algorithms/kem/ml_kem.yml | 51 +- docs/cbom.json | 95 +- .../copy_from_upstream/copy_from_upstream.yml | 11 +- src/common/pqclean_shims/fips202x4.h | 5 + src/kem/ml_kem/CMakeLists.txt | 90 +- src/kem/ml_kem/kem_ml_kem_1024.c | 76 +- src/kem/ml_kem/kem_ml_kem_512.c | 76 +- src/kem/ml_kem/kem_ml_kem_768.c | 76 +- .../LICENSE | 0 .../aarch64/README.md | 19 + .../aarch64/clean.h | 24 + .../aarch64/opt.h | 24 + .../aarch64/src/aarch64_zetas.c | 175 ++ .../aarch64/src/arith_native_aarch64.h | 90 + .../aarch64/src/clean_impl.h | 80 + .../aarch64/src/consts.h | 19 + .../aarch64/src/intt_clean.S | 364 ++++ .../aarch64/src/intt_opt.S | 1020 +++++++++++ .../aarch64/src/ntt_clean.S | 283 +++ .../aarch64/src/ntt_opt.S | 919 ++++++++++ .../aarch64/src/opt_impl.h | 81 + .../aarch64/src/optimize.sh | 121 ++ .../aarch64/src/poly_clean.S | 331 ++++ .../aarch64/src/poly_opt.S | 690 +++++++ .../aarch64/src/polyvec_clean.S | 288 +++ .../aarch64/src/polyvec_opt.S | 1584 +++++++++++++++++ .../aarch64/src/rej_uniform_asm_clean.S | 341 ++++ .../aarch64/src/rej_uniform_table.c | 288 +++ .../mlkem-native_ml-kem-1024_aarch64/api.h | 255 +++ .../arith_backend.h | 22 + .../mlkem-native_ml-kem-1024_aarch64/cbd.c | 156 ++ .../mlkem-native_ml-kem-1024_aarch64/cbd.h | 54 + .../mlkem-native_ml-kem-1024_aarch64/cbmc.h | 139 ++ .../mlkem-native_ml-kem-1024_aarch64/common.h | 65 + .../mlkem-native_ml-kem-1024_aarch64/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../default.h | 32 + .../mlkem-native_ml-kem-1024_aarch64/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-1024_aarch64/indcpa.h | 117 ++ .../mlkem-native_ml-kem-1024_aarch64/kem.c | 195 ++ .../mlkem-native_ml-kem-1024_aarch64/kem.h 
| 174 ++ .../mlkem_native.h | 241 +++ .../mlkem-native_ml-kem-1024_aarch64/ntt.c | 268 +++ .../mlkem-native_ml-kem-1024_aarch64/ntt.h | 103 ++ .../mlkem-native_ml-kem-1024_aarch64/params.h | 64 + .../mlkem-native_ml-kem-1024_aarch64/poly.c | 583 ++++++ .../mlkem-native_ml-kem-1024_aarch64/poly.h | 805 +++++++++ .../polyvec.c | 172 ++ .../polyvec.h | 332 ++++ .../mlkem-native_ml-kem-1024_aarch64/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../symmetric.h | 52 + .../mlkem-native_ml-kem-1024_aarch64/sys.h | 109 ++ .../mlkem-native_ml-kem-1024_aarch64/verify.c | 20 + .../mlkem-native_ml-kem-1024_aarch64/verify.h | 317 ++++ .../mlkem-native_ml-kem-1024_aarch64/zetas.c | 30 + .../LICENSE | 0 .../ml_kem/mlkem-native_ml-kem-1024_ref/api.h | 255 +++ .../arith_backend.h | 22 + .../ml_kem/mlkem-native_ml-kem-1024_ref/cbd.c | 156 ++ .../ml_kem/mlkem-native_ml-kem-1024_ref/cbd.h | 54 + .../mlkem-native_ml-kem-1024_ref/cbmc.h | 139 ++ .../mlkem-native_ml-kem-1024_ref/common.h | 65 + .../mlkem-native_ml-kem-1024_ref/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../mlkem-native_ml-kem-1024_ref/default.h | 32 + .../mlkem-native_ml-kem-1024_ref/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-1024_ref/indcpa.h | 117 ++ .../ml_kem/mlkem-native_ml-kem-1024_ref/kem.c | 195 ++ .../ml_kem/mlkem-native_ml-kem-1024_ref/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../ml_kem/mlkem-native_ml-kem-1024_ref/ntt.c | 268 +++ .../ml_kem/mlkem-native_ml-kem-1024_ref/ntt.h | 103 ++ .../mlkem-native_ml-kem-1024_ref/params.h | 64 + .../mlkem-native_ml-kem-1024_ref/poly.c | 583 ++++++ .../mlkem-native_ml-kem-1024_ref/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-1024_ref/polyvec.c | 172 ++ .../mlkem-native_ml-kem-1024_ref/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-1024_ref/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../mlkem-native_ml-kem-1024_ref/symmetric.h | 52 + .../ml_kem/mlkem-native_ml-kem-1024_ref/sys.h | 109 ++ 
.../mlkem-native_ml-kem-1024_ref/verify.c | 20 + .../mlkem-native_ml-kem-1024_ref/verify.h | 317 ++++ .../mlkem-native_ml-kem-1024_ref/zetas.c | 30 + .../LICENSE | 0 .../mlkem-native_ml-kem-1024_x86_64/api.h | 255 +++ .../arith_backend.h | 22 + .../mlkem-native_ml-kem-1024_x86_64/cbd.c | 156 ++ .../mlkem-native_ml-kem-1024_x86_64/cbd.h | 54 + .../mlkem-native_ml-kem-1024_x86_64/cbmc.h | 139 ++ .../mlkem-native_ml-kem-1024_x86_64/common.h | 65 + .../mlkem-native_ml-kem-1024_x86_64/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../mlkem-native_ml-kem-1024_x86_64/default.h | 32 + .../mlkem-native_ml-kem-1024_x86_64/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-1024_x86_64/indcpa.h | 117 ++ .../mlkem-native_ml-kem-1024_x86_64/kem.c | 195 ++ .../mlkem-native_ml-kem-1024_x86_64/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../mlkem-native_ml-kem-1024_x86_64/ntt.c | 268 +++ .../mlkem-native_ml-kem-1024_x86_64/ntt.h | 103 ++ .../mlkem-native_ml-kem-1024_x86_64/params.h | 64 + .../mlkem-native_ml-kem-1024_x86_64/poly.c | 583 ++++++ .../mlkem-native_ml-kem-1024_x86_64/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-1024_x86_64/polyvec.c | 172 ++ .../mlkem-native_ml-kem-1024_x86_64/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-1024_x86_64/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../symmetric.h | 52 + .../mlkem-native_ml-kem-1024_x86_64/sys.h | 109 ++ .../mlkem-native_ml-kem-1024_x86_64/verify.c | 20 + .../mlkem-native_ml-kem-1024_x86_64/verify.h | 317 ++++ .../x86_64/README.md | 4 + .../x86_64/default.h | 24 + .../x86_64/src/align.h | 31 + .../x86_64/src/arith_native_x86_64.h | 59 + .../x86_64/src}/basemul.S | 41 +- .../x86_64/src/basemul.c | 68 + .../x86_64/src/consts.c | 93 + .../x86_64/src/consts.h | 44 + .../x86_64/src/default_impl.h | 97 + .../x86_64/src}/fq.S | 61 +- .../x86_64/src}/fq.inc | 12 + .../x86_64/src/intt.S | 255 +++ .../x86_64/src/ntt.S | 219 +++ .../x86_64/src/rej_uniform_avx2.c | 131 ++ 
.../x86_64/src/rej_uniform_table.c | 159 ++ .../x86_64/src}/shuffle.S | 56 +- .../x86_64/src}/shuffle.inc | 16 +- .../x86_64/src/x86_64_zetas.i | 56 + .../mlkem-native_ml-kem-1024_x86_64/zetas.c | 30 + .../LICENSE | 0 .../aarch64/README.md | 19 + .../aarch64/clean.h | 24 + .../aarch64/opt.h | 24 + .../aarch64/src/aarch64_zetas.c | 175 ++ .../aarch64/src/arith_native_aarch64.h | 90 + .../aarch64/src/clean_impl.h | 80 + .../aarch64/src/consts.h | 19 + .../aarch64/src/intt_clean.S | 364 ++++ .../aarch64/src/intt_opt.S | 1020 +++++++++++ .../aarch64/src/ntt_clean.S | 283 +++ .../aarch64/src/ntt_opt.S | 919 ++++++++++ .../aarch64/src/opt_impl.h | 81 + .../aarch64/src/optimize.sh | 121 ++ .../aarch64/src/poly_clean.S | 331 ++++ .../aarch64/src/poly_opt.S | 690 +++++++ .../aarch64/src/polyvec_clean.S | 288 +++ .../aarch64/src/polyvec_opt.S | 1584 +++++++++++++++++ .../aarch64/src/rej_uniform_asm_clean.S | 341 ++++ .../aarch64/src/rej_uniform_table.c | 288 +++ .../mlkem-native_ml-kem-512_aarch64/api.h | 255 +++ .../arith_backend.h | 22 + .../mlkem-native_ml-kem-512_aarch64/cbd.c | 156 ++ .../mlkem-native_ml-kem-512_aarch64/cbd.h | 54 + .../mlkem-native_ml-kem-512_aarch64/cbmc.h | 139 ++ .../mlkem-native_ml-kem-512_aarch64/common.h | 65 + .../mlkem-native_ml-kem-512_aarch64/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../mlkem-native_ml-kem-512_aarch64/default.h | 32 + .../mlkem-native_ml-kem-512_aarch64/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-512_aarch64/indcpa.h | 117 ++ .../mlkem-native_ml-kem-512_aarch64/kem.c | 195 ++ .../mlkem-native_ml-kem-512_aarch64/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../mlkem-native_ml-kem-512_aarch64/ntt.c | 268 +++ .../mlkem-native_ml-kem-512_aarch64/ntt.h | 103 ++ .../mlkem-native_ml-kem-512_aarch64/params.h | 64 + .../mlkem-native_ml-kem-512_aarch64/poly.c | 583 ++++++ .../mlkem-native_ml-kem-512_aarch64/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-512_aarch64/polyvec.c | 172 ++ 
.../mlkem-native_ml-kem-512_aarch64/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-512_aarch64/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../symmetric.h | 52 + .../mlkem-native_ml-kem-512_aarch64/sys.h | 109 ++ .../mlkem-native_ml-kem-512_aarch64/verify.c | 20 + .../mlkem-native_ml-kem-512_aarch64/verify.h | 317 ++++ .../mlkem-native_ml-kem-512_aarch64/zetas.c | 30 + .../LICENSE | 0 .../ml_kem/mlkem-native_ml-kem-512_ref/api.h | 255 +++ .../arith_backend.h | 22 + .../ml_kem/mlkem-native_ml-kem-512_ref/cbd.c | 156 ++ .../ml_kem/mlkem-native_ml-kem-512_ref/cbd.h | 54 + .../ml_kem/mlkem-native_ml-kem-512_ref/cbmc.h | 139 ++ .../mlkem-native_ml-kem-512_ref/common.h | 65 + .../mlkem-native_ml-kem-512_ref/config.h | 144 ++ .../mlkem-native_ml-kem-512_ref/debug/debug.c | 56 + .../mlkem-native_ml-kem-512_ref/debug/debug.h | 224 +++ .../mlkem-native_ml-kem-512_ref/default.h | 32 + .../mlkem-native_ml-kem-512_ref/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-512_ref/indcpa.h | 117 ++ .../ml_kem/mlkem-native_ml-kem-512_ref/kem.c | 195 ++ .../ml_kem/mlkem-native_ml-kem-512_ref/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../ml_kem/mlkem-native_ml-kem-512_ref/ntt.c | 268 +++ .../ml_kem/mlkem-native_ml-kem-512_ref/ntt.h | 103 ++ .../mlkem-native_ml-kem-512_ref/params.h | 64 + .../ml_kem/mlkem-native_ml-kem-512_ref/poly.c | 583 ++++++ .../ml_kem/mlkem-native_ml-kem-512_ref/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-512_ref/polyvec.c | 172 ++ .../mlkem-native_ml-kem-512_ref/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-512_ref/reduce.h | 206 +++ .../mlkem-native_ml-kem-512_ref/rej_uniform.c | 106 ++ .../mlkem-native_ml-kem-512_ref/rej_uniform.h | 62 + .../mlkem-native_ml-kem-512_ref/symmetric.h | 52 + .../ml_kem/mlkem-native_ml-kem-512_ref/sys.h | 109 ++ .../mlkem-native_ml-kem-512_ref/verify.c | 20 + .../mlkem-native_ml-kem-512_ref/verify.h | 317 ++++ .../mlkem-native_ml-kem-512_ref/zetas.c | 30 + .../LICENSE | 0 
.../mlkem-native_ml-kem-512_x86_64/api.h | 255 +++ .../arith_backend.h | 22 + .../mlkem-native_ml-kem-512_x86_64/cbd.c | 156 ++ .../mlkem-native_ml-kem-512_x86_64/cbd.h | 54 + .../mlkem-native_ml-kem-512_x86_64/cbmc.h | 139 ++ .../mlkem-native_ml-kem-512_x86_64/common.h | 65 + .../mlkem-native_ml-kem-512_x86_64/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../mlkem-native_ml-kem-512_x86_64/default.h | 32 + .../mlkem-native_ml-kem-512_x86_64/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-512_x86_64/indcpa.h | 117 ++ .../mlkem-native_ml-kem-512_x86_64/kem.c | 195 ++ .../mlkem-native_ml-kem-512_x86_64/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../mlkem-native_ml-kem-512_x86_64/ntt.c | 268 +++ .../mlkem-native_ml-kem-512_x86_64/ntt.h | 103 ++ .../mlkem-native_ml-kem-512_x86_64/params.h | 64 + .../mlkem-native_ml-kem-512_x86_64/poly.c | 583 ++++++ .../mlkem-native_ml-kem-512_x86_64/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-512_x86_64/polyvec.c | 172 ++ .../mlkem-native_ml-kem-512_x86_64/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-512_x86_64/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../symmetric.h | 52 + .../mlkem-native_ml-kem-512_x86_64/sys.h | 109 ++ .../mlkem-native_ml-kem-512_x86_64/verify.c | 20 + .../mlkem-native_ml-kem-512_x86_64/verify.h | 317 ++++ .../x86_64/README.md | 4 + .../x86_64/default.h | 24 + .../x86_64/src/align.h | 31 + .../x86_64/src/arith_native_x86_64.h | 59 + .../x86_64/src}/basemul.S | 41 +- .../x86_64/src/basemul.c | 68 + .../x86_64/src/consts.c | 93 + .../x86_64/src/consts.h | 44 + .../x86_64/src/default_impl.h | 97 + .../x86_64/src}/fq.S | 61 +- .../x86_64/src}/fq.inc | 12 + .../x86_64/src/intt.S | 255 +++ .../x86_64/src/ntt.S | 219 +++ .../x86_64/src/rej_uniform_avx2.c | 131 ++ .../x86_64/src/rej_uniform_table.c | 159 ++ .../x86_64/src}/shuffle.S | 56 +- .../x86_64/src}/shuffle.inc | 16 +- .../x86_64/src/x86_64_zetas.i | 56 + .../mlkem-native_ml-kem-512_x86_64/zetas.c | 30 + 
.../mlkem-native_ml-kem-768_aarch64/LICENSE | 6 + .../aarch64/README.md | 19 + .../aarch64/clean.h | 24 + .../aarch64/opt.h | 24 + .../aarch64/src/aarch64_zetas.c | 175 ++ .../aarch64/src/arith_native_aarch64.h | 90 + .../aarch64/src/clean_impl.h | 80 + .../aarch64/src/consts.h | 19 + .../aarch64/src/intt_clean.S | 364 ++++ .../aarch64/src/intt_opt.S | 1020 +++++++++++ .../aarch64/src/ntt_clean.S | 283 +++ .../aarch64/src/ntt_opt.S | 919 ++++++++++ .../aarch64/src/opt_impl.h | 81 + .../aarch64/src/optimize.sh | 121 ++ .../aarch64/src/poly_clean.S | 331 ++++ .../aarch64/src/poly_opt.S | 690 +++++++ .../aarch64/src/polyvec_clean.S | 288 +++ .../aarch64/src/polyvec_opt.S | 1584 +++++++++++++++++ .../aarch64/src/rej_uniform_asm_clean.S | 341 ++++ .../aarch64/src/rej_uniform_table.c | 288 +++ .../mlkem-native_ml-kem-768_aarch64/api.h | 255 +++ .../arith_backend.h | 22 + .../mlkem-native_ml-kem-768_aarch64/cbd.c | 156 ++ .../mlkem-native_ml-kem-768_aarch64/cbd.h | 54 + .../mlkem-native_ml-kem-768_aarch64/cbmc.h | 139 ++ .../mlkem-native_ml-kem-768_aarch64/common.h | 65 + .../mlkem-native_ml-kem-768_aarch64/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../mlkem-native_ml-kem-768_aarch64/default.h | 32 + .../mlkem-native_ml-kem-768_aarch64/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-768_aarch64/indcpa.h | 117 ++ .../mlkem-native_ml-kem-768_aarch64/kem.c | 195 ++ .../mlkem-native_ml-kem-768_aarch64/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../mlkem-native_ml-kem-768_aarch64/ntt.c | 268 +++ .../mlkem-native_ml-kem-768_aarch64/ntt.h | 103 ++ .../mlkem-native_ml-kem-768_aarch64/params.h | 64 + .../mlkem-native_ml-kem-768_aarch64/poly.c | 583 ++++++ .../mlkem-native_ml-kem-768_aarch64/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-768_aarch64/polyvec.c | 172 ++ .../mlkem-native_ml-kem-768_aarch64/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-768_aarch64/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../symmetric.h | 52 + 
.../mlkem-native_ml-kem-768_aarch64/sys.h | 109 ++ .../mlkem-native_ml-kem-768_aarch64/verify.c | 20 + .../mlkem-native_ml-kem-768_aarch64/verify.h | 317 ++++ .../mlkem-native_ml-kem-768_aarch64/zetas.c | 30 + .../mlkem-native_ml-kem-768_ref/LICENSE | 6 + .../ml_kem/mlkem-native_ml-kem-768_ref/api.h | 255 +++ .../arith_backend.h | 22 + .../ml_kem/mlkem-native_ml-kem-768_ref/cbd.c | 156 ++ .../ml_kem/mlkem-native_ml-kem-768_ref/cbd.h | 54 + .../ml_kem/mlkem-native_ml-kem-768_ref/cbmc.h | 139 ++ .../mlkem-native_ml-kem-768_ref/common.h | 65 + .../mlkem-native_ml-kem-768_ref/config.h | 144 ++ .../mlkem-native_ml-kem-768_ref/debug/debug.c | 56 + .../mlkem-native_ml-kem-768_ref/debug/debug.h | 224 +++ .../mlkem-native_ml-kem-768_ref/default.h | 32 + .../mlkem-native_ml-kem-768_ref/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-768_ref/indcpa.h | 117 ++ .../ml_kem/mlkem-native_ml-kem-768_ref/kem.c | 195 ++ .../ml_kem/mlkem-native_ml-kem-768_ref/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../ml_kem/mlkem-native_ml-kem-768_ref/ntt.c | 268 +++ .../ml_kem/mlkem-native_ml-kem-768_ref/ntt.h | 103 ++ .../mlkem-native_ml-kem-768_ref/params.h | 64 + .../ml_kem/mlkem-native_ml-kem-768_ref/poly.c | 583 ++++++ .../ml_kem/mlkem-native_ml-kem-768_ref/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-768_ref/polyvec.c | 172 ++ .../mlkem-native_ml-kem-768_ref/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-768_ref/reduce.h | 206 +++ .../mlkem-native_ml-kem-768_ref/rej_uniform.c | 106 ++ .../mlkem-native_ml-kem-768_ref/rej_uniform.h | 62 + .../mlkem-native_ml-kem-768_ref/symmetric.h | 52 + .../ml_kem/mlkem-native_ml-kem-768_ref/sys.h | 109 ++ .../mlkem-native_ml-kem-768_ref/verify.c | 20 + .../mlkem-native_ml-kem-768_ref/verify.h | 317 ++++ .../mlkem-native_ml-kem-768_ref/zetas.c | 30 + .../mlkem-native_ml-kem-768_x86_64/LICENSE | 6 + .../mlkem-native_ml-kem-768_x86_64/api.h | 255 +++ .../arith_backend.h | 22 + .../mlkem-native_ml-kem-768_x86_64/cbd.c | 156 ++ 
.../mlkem-native_ml-kem-768_x86_64/cbd.h | 54 + .../mlkem-native_ml-kem-768_x86_64/cbmc.h | 139 ++ .../mlkem-native_ml-kem-768_x86_64/common.h | 65 + .../mlkem-native_ml-kem-768_x86_64/config.h | 144 ++ .../debug/debug.c | 56 + .../debug/debug.h | 224 +++ .../mlkem-native_ml-kem-768_x86_64/default.h | 32 + .../mlkem-native_ml-kem-768_x86_64/indcpa.c | 559 ++++++ .../mlkem-native_ml-kem-768_x86_64/indcpa.h | 117 ++ .../mlkem-native_ml-kem-768_x86_64/kem.c | 195 ++ .../mlkem-native_ml-kem-768_x86_64/kem.h | 174 ++ .../mlkem_native.h | 241 +++ .../mlkem-native_ml-kem-768_x86_64/ntt.c | 268 +++ .../mlkem-native_ml-kem-768_x86_64/ntt.h | 103 ++ .../mlkem-native_ml-kem-768_x86_64/params.h | 64 + .../mlkem-native_ml-kem-768_x86_64/poly.c | 583 ++++++ .../mlkem-native_ml-kem-768_x86_64/poly.h | 805 +++++++++ .../mlkem-native_ml-kem-768_x86_64/polyvec.c | 172 ++ .../mlkem-native_ml-kem-768_x86_64/polyvec.h | 332 ++++ .../mlkem-native_ml-kem-768_x86_64/reduce.h | 206 +++ .../rej_uniform.c | 106 ++ .../rej_uniform.h | 62 + .../symmetric.h | 52 + .../mlkem-native_ml-kem-768_x86_64/sys.h | 109 ++ .../mlkem-native_ml-kem-768_x86_64/verify.c | 20 + .../mlkem-native_ml-kem-768_x86_64/verify.h | 317 ++++ .../x86_64/README.md | 4 + .../x86_64/default.h | 24 + .../x86_64/src/align.h | 31 + .../x86_64/src/arith_native_x86_64.h | 59 + .../x86_64/src}/basemul.S | 41 +- .../x86_64/src/basemul.c | 68 + .../x86_64/src/consts.c | 93 + .../x86_64/src/consts.h | 44 + .../x86_64/src/default_impl.h | 97 + .../x86_64/src}/fq.S | 61 +- .../x86_64/src}/fq.inc | 12 + .../x86_64/src/intt.S | 255 +++ .../x86_64/src/ntt.S | 219 +++ .../x86_64/src/rej_uniform_avx2.c | 131 ++ .../x86_64/src/rej_uniform_table.c | 159 ++ .../x86_64/src}/shuffle.S | 56 +- .../x86_64/src}/shuffle.inc | 16 +- .../x86_64/src/x86_64_zetas.i | 56 + .../mlkem-native_ml-kem-768_x86_64/zetas.c | 30 + .../align.h | 19 - .../api.h | 66 - .../cbd.c | 144 -- .../cbd.h | 15 - .../consts.c | 121 -- .../consts.h | 43 - .../indcpa.c | 568 
------ .../indcpa.h | 27 - .../invntt.S | 193 -- .../kem.c | 169 -- .../kem.h | 35 - .../ntt.S | 189 -- .../ntt.h | 28 - .../params.h | 68 - .../poly.c | 519 ------ .../poly.h | 77 - .../polyvec.c | 307 ---- .../polyvec.h | 36 - .../reduce.h | 12 - .../rejsample.c | 398 ----- .../rejsample.h | 14 - .../symmetric-shake.c | 74 - .../symmetric.h | 34 - .../verify.c | 83 - .../verify.h | 17 - .../api.h | 66 - .../cbd.c | 128 -- .../cbd.h | 14 - .../indcpa.c | 334 ---- .../indcpa.h | 27 - .../kem.c | 169 -- .../kem.h | 35 - .../ntt.c | 146 -- .../ntt.h | 19 - .../params.h | 55 - .../poly.c | 360 ---- .../poly.h | 53 - .../polyvec.c | 246 --- .../polyvec.h | 36 - .../reduce.c | 42 - .../reduce.h | 16 - .../symmetric-shake.c | 74 - .../symmetric.h | 35 - .../verify.c | 75 - .../verify.h | 17 - .../align.h | 19 - .../api.h | 66 - .../cbd.c | 144 -- .../cbd.h | 15 - .../consts.c | 121 -- .../consts.h | 43 - .../indcpa.c | 568 ------ .../indcpa.h | 27 - .../invntt.S | 193 -- .../kem.c | 169 -- .../kem.h | 35 - .../ntt.S | 189 -- .../ntt.h | 28 - .../params.h | 68 - .../poly.c | 519 ------ .../poly.h | 77 - .../polyvec.c | 307 ---- .../polyvec.h | 36 - .../reduce.h | 12 - .../rejsample.c | 398 ----- .../rejsample.h | 14 - .../symmetric-shake.c | 74 - .../symmetric.h | 34 - .../verify.c | 83 - .../verify.h | 17 - .../api.h | 66 - .../cbd.c | 128 -- .../cbd.h | 14 - .../indcpa.c | 334 ---- .../indcpa.h | 27 - .../kem.c | 169 -- .../kem.h | 35 - .../ntt.c | 146 -- .../ntt.h | 19 - .../params.h | 55 - .../poly.c | 360 ---- .../poly.h | 53 - .../polyvec.c | 246 --- .../polyvec.h | 36 - .../reduce.c | 42 - .../reduce.h | 16 - .../symmetric-shake.c | 74 - .../symmetric.h | 35 - .../verify.c | 75 - .../verify.h | 17 - .../align.h | 19 - .../api.h | 66 - .../cbd.c | 144 -- .../cbd.h | 15 - .../consts.c | 121 -- .../consts.h | 43 - .../indcpa.c | 568 ------ .../indcpa.h | 27 - .../invntt.S | 193 -- .../kem.c | 169 -- .../kem.h | 35 - .../ntt.S | 189 -- .../ntt.h | 28 - .../params.h | 
68 - .../poly.c | 519 ------ .../poly.h | 77 - .../polyvec.c | 307 ---- .../polyvec.h | 36 - .../reduce.h | 12 - .../rejsample.c | 398 ----- .../rejsample.h | 14 - .../symmetric-shake.c | 74 - .../symmetric.h | 34 - .../verify.c | 83 - .../verify.h | 17 - .../api.h | 66 - .../cbd.c | 128 -- .../cbd.h | 14 - .../indcpa.c | 334 ---- .../indcpa.h | 27 - .../kem.c | 169 -- .../kem.h | 35 - .../ntt.c | 146 -- .../ntt.h | 19 - .../params.h | 55 - .../poly.c | 360 ---- .../poly.h | 53 - .../polyvec.c | 246 --- .../polyvec.h | 36 - .../reduce.c | 42 - .../reduce.h | 16 - .../symmetric-shake.c | 74 - .../symmetric.h | 35 - .../verify.c | 75 - .../verify.h | 17 - src/oqsconfig.h.cmake | 9 +- tests/test_binary.py | 2 +- 537 files changed, 75605 insertions(+), 16007 deletions(-) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_avx2 => mlkem-native_ml-kem-1024_aarch64}/LICENSE (100%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/README.md create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/clean.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/opt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/aarch64_zetas.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/arith_native_aarch64.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/clean_impl.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/consts.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/opt_impl.h create mode 100755 
src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/optimize.sh create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_asm_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_table.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.h create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/zetas.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_ref => mlkem-native_ml-kem-1024_ref}/LICENSE (100%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.h create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/zetas.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_avx2 => mlkem-native_ml-kem-1024_x86_64}/LICENSE (100%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.c 
create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/README.md create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/align.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/arith_native_x86_64.h rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_avx2 
=> mlkem-native_ml-kem-1024_x86_64/x86_64/src}/basemul.S (61%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/default_impl.h rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_avx2 => mlkem-native_ml-kem-1024_x86_64/x86_64/src}/fq.S (50%) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_avx2 => mlkem-native_ml-kem-1024_x86_64/x86_64/src}/fq.inc (67%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/intt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/ntt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_avx2.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_table.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_avx2 => mlkem-native_ml-kem-1024_x86_64/x86_64/src}/shuffle.S (81%) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_avx2 => mlkem-native_ml-kem-1024_x86_64/x86_64/src}/shuffle.inc (55%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/x86_64_zetas.i create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/zetas.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_ref => mlkem-native_ml-kem-512_aarch64}/LICENSE (100%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/README.md create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/clean.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/opt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/aarch64_zetas.c create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/arith_native_aarch64.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/clean_impl.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/consts.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/opt_impl.h create mode 100755 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/optimize.sh create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_asm_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_table.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.c create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/zetas.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_avx2 => mlkem-native_ml-kem-512_ref}/LICENSE (100%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.c create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_ref/zetas.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_ref => 
mlkem-native_ml-kem-512_x86_64}/LICENSE (100%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/symmetric.h create mode 
100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/README.md create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/align.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/arith_native_x86_64.h rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_avx2 => mlkem-native_ml-kem-512_x86_64/x86_64/src}/basemul.S (61%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/default_impl.h rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_avx2 => mlkem-native_ml-kem-512_x86_64/x86_64/src}/fq.S (50%) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_avx2 => mlkem-native_ml-kem-512_x86_64/x86_64/src}/fq.inc (67%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/intt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/ntt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_avx2.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_table.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_avx2 => mlkem-native_ml-kem-512_x86_64/x86_64/src}/shuffle.S (81%) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_avx2 => mlkem-native_ml-kem-512_x86_64/x86_64/src}/shuffle.inc (55%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/x86_64_zetas.i create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/zetas.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/LICENSE create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/README.md create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/clean.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/opt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/aarch64_zetas.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/arith_native_aarch64.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/clean_impl.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/consts.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/opt_impl.h create mode 100755 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/optimize.sh create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_opt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_asm_clean.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_table.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/api.h create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/sys.h create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/zetas.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/LICENSE create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.c create mode 
100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_ref/zetas.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/LICENSE create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/api.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/arith_backend.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbmc.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/common.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/config.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/mlkem_native.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/params.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.h create mode 100644 
src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/reduce.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/symmetric.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/sys.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/README.md create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/default.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/align.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/arith_native_x86_64.h rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_avx2 => mlkem-native_ml-kem-768_x86_64/x86_64/src}/basemul.S (61%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.c create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.h create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/default_impl.h rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_avx2 => mlkem-native_ml-kem-768_x86_64/x86_64/src}/fq.S (50%) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-1024_avx2 => mlkem-native_ml-kem-768_x86_64/x86_64/src}/fq.inc (67%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/intt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/ntt.S create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_avx2.c create mode 
100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_table.c rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-768_avx2 => mlkem-native_ml-kem-768_x86_64/x86_64/src}/shuffle.S (81%) rename src/kem/ml_kem/{pqcrystals-kyber-standard_ml-kem-512_avx2 => mlkem-native_ml-kem-768_x86_64/x86_64/src}/shuffle.inc (55%) create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/x86_64_zetas.i create mode 100644 src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/zetas.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/align.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/api.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/indcpa.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/indcpa.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/invntt.S delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.S delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/params.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.c delete mode 100644 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/reduce.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric-shake.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/verify.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/verify.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/api.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/indcpa.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/indcpa.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/params.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.c delete mode 100644 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric-shake.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/verify.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/verify.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/align.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/api.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/indcpa.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/indcpa.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/invntt.S delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.S delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/params.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.h delete mode 100644 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/reduce.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric-shake.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/verify.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/verify.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/api.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/indcpa.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/indcpa.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/params.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.h delete mode 100644 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric-shake.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/verify.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/verify.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/align.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/api.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/indcpa.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/indcpa.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/invntt.S delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.S delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/params.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/reduce.h delete mode 100644 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric-shake.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/verify.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/verify.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/api.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/indcpa.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/indcpa.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/params.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric-shake.c delete mode 100644 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric.h delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/verify.c delete mode 100644 src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/verify.h diff --git a/.CMake/alg_support.cmake b/.CMake/alg_support.cmake index 9afa6e4b15..37b47e5d7c 100644 --- a/.CMake/alg_support.cmake +++ b/.CMake/alg_support.cmake @@ -334,19 +334,43 @@ endif() if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") if(OQS_DIST_X86_64_BUILD OR (OQS_USE_AVX2_INSTRUCTIONS AND OQS_USE_BMI2_INSTRUCTIONS AND OQS_USE_POPCNT_INSTRUCTIONS)) - cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_512_avx2 "" ON "OQS_ENABLE_KEM_ml_kem_512" OFF) + cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_512_x86_64 "" ON "OQS_ENABLE_KEM_ml_kem_512" OFF) +endif() +endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") +if((OQS_DIST_ARM64_V8_BUILD OR (OQS_USE_ARM_NEON_INSTRUCTIONS AND OQS_USE_ARM_NEON_INSTRUCTIONS))) + + cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_512_aarch64 "" ON "OQS_ENABLE_KEM_ml_kem_512" OFF) + endif() endif() if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") if(OQS_DIST_X86_64_BUILD OR (OQS_USE_AVX2_INSTRUCTIONS AND OQS_USE_BMI2_INSTRUCTIONS AND OQS_USE_POPCNT_INSTRUCTIONS)) - cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_768_avx2 "" ON "OQS_ENABLE_KEM_ml_kem_768" OFF) + cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_768_x86_64 "" ON "OQS_ENABLE_KEM_ml_kem_768" OFF) +endif() +endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") +if((OQS_DIST_ARM64_V8_BUILD OR (OQS_USE_ARM_NEON_INSTRUCTIONS AND OQS_USE_ARM_NEON_INSTRUCTIONS))) + + cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_768_aarch64 "" ON "OQS_ENABLE_KEM_ml_kem_768" OFF) + endif() endif() if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") if(OQS_DIST_X86_64_BUILD OR (OQS_USE_AVX2_INSTRUCTIONS AND OQS_USE_BMI2_INSTRUCTIONS AND OQS_USE_POPCNT_INSTRUCTIONS)) - cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_1024_avx2 "" ON "OQS_ENABLE_KEM_ml_kem_1024" OFF) + 
cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_1024_x86_64 "" ON "OQS_ENABLE_KEM_ml_kem_1024" OFF) +endif() +endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Linux|Darwin") +if((OQS_DIST_ARM64_V8_BUILD OR (OQS_USE_ARM_NEON_INSTRUCTIONS AND OQS_USE_ARM_NEON_INSTRUCTIONS))) + + cmake_dependent_option(OQS_ENABLE_KEM_ml_kem_1024_aarch64 "" ON "OQS_ENABLE_KEM_ml_kem_1024" OFF) + endif() endif() diff --git a/docs/algorithms/kem/ml_kem.md b/docs/algorithms/kem/ml_kem.md index d1806517ba..fc6afdb9ed 100644 --- a/docs/algorithms/kem/ml_kem.md +++ b/docs/algorithms/kem/ml_kem.md @@ -7,7 +7,7 @@ - **Authors' website**: https://pq-crystals.org/kyber/ and https://csrc.nist.gov/pubs/fips/203 - **Specification version**: ML-KEM. - **Primary Source**: - - **Source**: https://github.com/pq-crystals/kyber/commit/10b478fc3cc4ff6215eb0b6a11bd758bf0929cbd with copy_from_upstream patches + - **Source**: https://github.com/bhess/mlkem-native/commit/86cc8d0cd3af1dff13228296cbadbbfc6928890c - **Implementation license (SPDX-Identifier)**: CC0-1.0 or Apache-2.0 @@ -24,7 +24,6 @@ | Implementation source | Identifier in upstream | Supported architecture(s) | Supported operating system(s) | CPU extension(s) used | No branching-on-secrets claimed? | No branching-on-secrets checked by valgrind? | Large stack usage?‡ | |:---------------------------------:|:-------------------------|:----------------------------|:--------------------------------|:------------------------|:-----------------------------------|:-----------------------------------------------|:----------------------| | [Primary Source](#primary-source) | ref | All | All | None | True | True | False | -| [Primary Source](#primary-source) | avx2 | x86\_64 | Linux,Darwin | AVX2,BMI2,POPCNT | True | True | False | Are implementations chosen based on runtime CPU feature detection? **Yes**. @@ -35,7 +34,6 @@ Are implementations chosen based on runtime CPU feature detection? **Yes**. 
| Implementation source | Identifier in upstream | Supported architecture(s) | Supported operating system(s) | CPU extension(s) used | No branching-on-secrets claimed? | No branching-on-secrets checked by valgrind? | Large stack usage? | |:---------------------------------:|:-------------------------|:----------------------------|:--------------------------------|:------------------------|:-----------------------------------|:-----------------------------------------------|:---------------------| | [Primary Source](#primary-source) | ref | All | All | None | True | True | False | -| [Primary Source](#primary-source) | avx2 | x86\_64 | Linux,Darwin | AVX2,BMI2,POPCNT | True | True | False | Are implementations chosen based on runtime CPU feature detection? **Yes**. @@ -44,7 +42,6 @@ Are implementations chosen based on runtime CPU feature detection? **Yes**. | Implementation source | Identifier in upstream | Supported architecture(s) | Supported operating system(s) | CPU extension(s) used | No branching-on-secrets claimed? | No branching-on-secrets checked by valgrind? | Large stack usage? | |:---------------------------------:|:-------------------------|:----------------------------|:--------------------------------|:------------------------|:-----------------------------------|:-----------------------------------------------|:---------------------| | [Primary Source](#primary-source) | ref | All | All | None | True | True | False | -| [Primary Source](#primary-source) | avx2 | x86\_64 | Linux,Darwin | AVX2,BMI2,POPCNT | True | True | False | Are implementations chosen based on runtime CPU feature detection? **Yes**. 
diff --git a/docs/algorithms/kem/ml_kem.yml b/docs/algorithms/kem/ml_kem.yml index 81ef2b6c4a..62e533d779 100644 --- a/docs/algorithms/kem/ml_kem.yml +++ b/docs/algorithms/kem/ml_kem.yml @@ -17,8 +17,7 @@ website: https://pq-crystals.org/kyber/ and https://csrc.nist.gov/pubs/fips/203 nist-round: FIPS203 spec-version: ML-KEM primary-upstream: - source: https://github.com/pq-crystals/kyber/commit/10b478fc3cc4ff6215eb0b6a11bd758bf0929cbd - with copy_from_upstream patches + source: https://github.com/bhess/mlkem-native/commit/86cc8d0cd3af1dff13228296cbadbbfc6928890c spdx-license-identifier: CC0-1.0 or Apache-2.0 parameter-sets: - name: ML-KEM-512 @@ -38,22 +37,6 @@ parameter-sets: no-secret-dependent-branching-claimed: true no-secret-dependent-branching-checked-by-valgrind: true large-stack-usage: false - - upstream: primary-upstream - upstream-id: avx2 - supported-platforms: - - architecture: x86_64 - operating_systems: - - Linux - - Darwin - required_flags: - - avx2 - - bmi2 - - popcnt - common-crypto: - - SHA3: liboqs - no-secret-dependent-branching-claimed: true - no-secret-dependent-branching-checked-by-valgrind: true - large-stack-usage: false - name: ML-KEM-768 claimed-nist-level: 3 claimed-security: IND-CCA2 @@ -71,22 +54,6 @@ parameter-sets: no-secret-dependent-branching-claimed: true no-secret-dependent-branching-checked-by-valgrind: true large-stack-usage: false - - upstream: primary-upstream - upstream-id: avx2 - supported-platforms: - - architecture: x86_64 - operating_systems: - - Linux - - Darwin - required_flags: - - avx2 - - bmi2 - - popcnt - common-crypto: - - SHA3: liboqs - no-secret-dependent-branching-claimed: true - no-secret-dependent-branching-checked-by-valgrind: true - large-stack-usage: false - name: ML-KEM-1024 claimed-nist-level: 5 claimed-security: IND-CCA2 @@ -104,19 +71,3 @@ parameter-sets: no-secret-dependent-branching-claimed: true no-secret-dependent-branching-checked-by-valgrind: true large-stack-usage: false - - upstream: 
primary-upstream - upstream-id: avx2 - supported-platforms: - - architecture: x86_64 - operating_systems: - - Linux - - Darwin - required_flags: - - avx2 - - bmi2 - - popcnt - common-crypto: - - SHA3: liboqs - no-secret-dependent-branching-claimed: true - no-secret-dependent-branching-checked-by-valgrind: true - large-stack-usage: false diff --git a/docs/cbom.json b/docs/cbom.json index 52cf0a0a59..a9361e3756 100644 --- a/docs/cbom.json +++ b/docs/cbom.json @@ -2,23 +2,23 @@ "$schema": "https://raw.githubusercontent.com/CycloneDX/specification/1.6/schema/bom-1.6.schema.json", "bomFormat": "CycloneDX", "specVersion": "1.6", - "serialNumber": "urn:uuid:de1355bb-9681-4a7e-8aa9-0ccc414ebe3b", + "serialNumber": "urn:uuid:d66add05-17dd-4986-8894-ed47d1e910b6", "version": 1, "metadata": { - "timestamp": "2024-11-05T12:25:53.012740+00:00", + "timestamp": "2024-12-09T14:24:28.343759+00:00", "component": { "type": "library", - "bom-ref": "pkg:github/open-quantum-safe/liboqs@69a80f8a66988521d51e94d716cff8c936c07b8d", + "bom-ref": "pkg:github/open-quantum-safe/liboqs@d0d0413dc9fff538296ab86bac492cb4bf54dedb", "name": "liboqs", - "version": "69a80f8a66988521d51e94d716cff8c936c07b8d" + "version": "d0d0413dc9fff538296ab86bac492cb4bf54dedb" } }, "components": [ { "type": "library", - "bom-ref": "pkg:github/open-quantum-safe/liboqs@69a80f8a66988521d51e94d716cff8c936c07b8d", + "bom-ref": "pkg:github/open-quantum-safe/liboqs@d0d0413dc9fff538296ab86bac492cb4bf54dedb", "name": "liboqs", - "version": "69a80f8a66988521d51e94d716cff8c936c07b8d" + "version": "d0d0413dc9fff538296ab86bac492cb4bf54dedb" }, { "type": "cryptographic-asset", @@ -1060,26 +1060,6 @@ } } }, - { - "type": "cryptographic-asset", - "bom-ref": "alg:ML-KEM-512:x86_64", - "name": "ML-KEM", - "cryptoProperties": { - "assetType": "algorithm", - "algorithmProperties": { - "parameterSetIdentifier": "ML-KEM-512", - "primitive": "kem", - "executionEnvironment": "software-plain-ram", - "cryptoFunctions": [ - "keygen", - 
"encapsulate", - "decapsulate" - ], - "nistQuantumSecurityLevel": 1, - "implementationPlatform": "x86_64" - } - } - }, { "type": "cryptographic-asset", "bom-ref": "alg:ML-KEM-768:generic", @@ -1100,26 +1080,6 @@ } } }, - { - "type": "cryptographic-asset", - "bom-ref": "alg:ML-KEM-768:x86_64", - "name": "ML-KEM", - "cryptoProperties": { - "assetType": "algorithm", - "algorithmProperties": { - "parameterSetIdentifier": "ML-KEM-768", - "primitive": "kem", - "executionEnvironment": "software-plain-ram", - "cryptoFunctions": [ - "keygen", - "encapsulate", - "decapsulate" - ], - "nistQuantumSecurityLevel": 3, - "implementationPlatform": "x86_64" - } - } - }, { "type": "cryptographic-asset", "bom-ref": "alg:ML-KEM-1024:generic", @@ -1140,26 +1100,6 @@ } } }, - { - "type": "cryptographic-asset", - "bom-ref": "alg:ML-KEM-1024:x86_64", - "name": "ML-KEM", - "cryptoProperties": { - "assetType": "algorithm", - "algorithmProperties": { - "parameterSetIdentifier": "ML-KEM-1024", - "primitive": "kem", - "executionEnvironment": "software-plain-ram", - "cryptoFunctions": [ - "keygen", - "encapsulate", - "decapsulate" - ], - "nistQuantumSecurityLevel": 5, - "implementationPlatform": "x86_64" - } - } - }, { "type": "cryptographic-asset", "bom-ref": "alg:sntrup761:generic", @@ -3127,7 +3067,7 @@ ], "dependencies": [ { - "ref": "pkg:github/open-quantum-safe/liboqs@69a80f8a66988521d51e94d716cff8c936c07b8d", + "ref": "pkg:github/open-quantum-safe/liboqs@d0d0413dc9fff538296ab86bac492cb4bf54dedb", "provides": [ "alg:BIKE-L1:x86_64", "alg:BIKE-L3:x86_64", @@ -3181,11 +3121,8 @@ "alg:Kyber1024:x86_64", "alg:Kyber1024:armv8-a", "alg:ML-KEM-512:generic", - "alg:ML-KEM-512:x86_64", "alg:ML-KEM-768:generic", - "alg:ML-KEM-768:x86_64", "alg:ML-KEM-1024:generic", - "alg:ML-KEM-1024:x86_64", "alg:sntrup761:generic", "alg:sntrup761:x86_64", "alg:cross-rsdp-128-balanced:generic", @@ -3605,36 +3542,18 @@ "alg:sha3" ] }, - { - "ref": "alg:ML-KEM-512:x86_64", - "dependsOn": [ - "alg:sha3" - ] - }, { 
"ref": "alg:ML-KEM-768:generic", "dependsOn": [ "alg:sha3" ] }, - { - "ref": "alg:ML-KEM-768:x86_64", - "dependsOn": [ - "alg:sha3" - ] - }, { "ref": "alg:ML-KEM-1024:generic", "dependsOn": [ "alg:sha3" ] }, - { - "ref": "alg:ML-KEM-1024:x86_64", - "dependsOn": [ - "alg:sha3" - ] - }, { "ref": "alg:sntrup761:generic", "dependsOn": [ diff --git a/scripts/copy_from_upstream/copy_from_upstream.yml b/scripts/copy_from_upstream/copy_from_upstream.yml index f80f0979d5..0f22dda1fd 100644 --- a/scripts/copy_from_upstream/copy_from_upstream.yml +++ b/scripts/copy_from_upstream/copy_from_upstream.yml @@ -31,13 +31,12 @@ upstreams: kem_scheme_path: '.' patches: [pqcrystals-kyber-yml.patch, pqcrystals-kyber-ref-shake-aes.patch, pqcrystals-kyber-avx2-shake-aes.patch] - - name: pqcrystals-kyber-standard - git_url: https://github.com/pq-crystals/kyber.git - git_branch: main - git_commit: 10b478fc3cc4ff6215eb0b6a11bd758bf0929cbd + name: mlkem-native + git_url: https://github.com/bhess/mlkem-native.git + git_branch: updates-5 + git_commit: 86cc8d0cd3af1dff13228296cbadbbfc6928890c kem_meta_path: '{pretty_name_full}_META.yml' kem_scheme_path: '.' 
- patches: [pqcrystals-ml_kem.patch] - name: pqcrystals-dilithium git_url: https://github.com/pq-crystals/dilithium.git @@ -166,7 +165,7 @@ kems: - name: ml_kem default_implementation: ref - upstream_location: pqcrystals-kyber-standard + upstream_location: mlkem-native schemes: - scheme: "512" diff --git a/src/common/pqclean_shims/fips202x4.h b/src/common/pqclean_shims/fips202x4.h index c1f7ffcf0e..a1beda0c99 100644 --- a/src/common/pqclean_shims/fips202x4.h +++ b/src/common/pqclean_shims/fips202x4.h @@ -48,4 +48,9 @@ void OQS_SHA3_shake256_x4_absorb_once(shake256x4incctx *state, const uint8_t *in #define shake256x4_squeezeblocks(OUT0, OUT1, OUT2, OUT3, NBLOCKS, STATE) \ OQS_SHA3_shake256_x4_inc_squeeze(OUT0, OUT1, OUT2, OUT3, (NBLOCKS)*OQS_SHA3_SHAKE256_RATE, STATE) +#define shake128x4ctx shake128x4incctx +#define shake128x4_release shake128x4_inc_ctx_release +#define shake128ctx shake128incctx +#define shake128_release shake128_inc_ctx_release + #endif diff --git a/src/kem/ml_kem/CMakeLists.txt b/src/kem/ml_kem/CMakeLists.txt index 14cc9b850d..edd305ce88 100644 --- a/src/kem/ml_kem/CMakeLists.txt +++ b/src/kem/ml_kem/CMakeLists.txt @@ -6,57 +6,81 @@ set(_ML_KEM_OBJS "") if(OQS_ENABLE_KEM_ml_kem_512) - add_library(ml_kem_512_ref OBJECT kem_ml_kem_512.c pqcrystals-kyber-standard_ml-kem-512_ref/cbd.c pqcrystals-kyber-standard_ml-kem-512_ref/indcpa.c pqcrystals-kyber-standard_ml-kem-512_ref/kem.c pqcrystals-kyber-standard_ml-kem-512_ref/ntt.c pqcrystals-kyber-standard_ml-kem-512_ref/poly.c pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.c pqcrystals-kyber-standard_ml-kem-512_ref/reduce.c pqcrystals-kyber-standard_ml-kem-512_ref/symmetric-shake.c pqcrystals-kyber-standard_ml-kem-512_ref/verify.c) - target_compile_options(ml_kem_512_ref PUBLIC -DKYBER_K=2) - target_include_directories(ml_kem_512_ref PRIVATE ${CMAKE_CURRENT_LIST_DIR}/pqcrystals-kyber-standard_ml-kem-512_ref) + add_library(ml_kem_512_ref OBJECT kem_ml_kem_512.c mlkem-native_ml-kem-512_ref/cbd.c 
mlkem-native_ml-kem-512_ref/debug/debug.c mlkem-native_ml-kem-512_ref/indcpa.c mlkem-native_ml-kem-512_ref/kem.c mlkem-native_ml-kem-512_ref/ntt.c mlkem-native_ml-kem-512_ref/poly.c mlkem-native_ml-kem-512_ref/polyvec.c mlkem-native_ml-kem-512_ref/rej_uniform.c mlkem-native_ml-kem-512_ref/verify.c mlkem-native_ml-kem-512_ref/zetas.c) + target_compile_options(ml_kem_512_ref PUBLIC -DMLKEM_K=2 -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM512_C) + target_include_directories(ml_kem_512_ref PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-512_ref) target_include_directories(ml_kem_512_ref PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) - target_compile_options(ml_kem_512_ref PUBLIC -DKYBER_K=2) + target_compile_options(ml_kem_512_ref PUBLIC -DMLKEM_K=2 -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM512_C) set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) endif() -if(OQS_ENABLE_KEM_ml_kem_512_avx2) - add_library(ml_kem_512_avx2 OBJECT pqcrystals-kyber-standard_ml-kem-512_avx2/basemul.S pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.c pqcrystals-kyber-standard_ml-kem-512_avx2/consts.c pqcrystals-kyber-standard_ml-kem-512_avx2/fq.S pqcrystals-kyber-standard_ml-kem-512_avx2/indcpa.c pqcrystals-kyber-standard_ml-kem-512_avx2/invntt.S pqcrystals-kyber-standard_ml-kem-512_avx2/kem.c pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.S pqcrystals-kyber-standard_ml-kem-512_avx2/poly.c pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.c pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.c pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.S pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric-shake.c pqcrystals-kyber-standard_ml-kem-512_avx2/verify.c) - target_include_directories(ml_kem_512_avx2 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/pqcrystals-kyber-standard_ml-kem-512_avx2) - target_include_directories(ml_kem_512_avx2 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) - target_compile_options(ml_kem_512_avx2 PRIVATE -mavx2 -mbmi2 -mpopcnt ) - target_compile_options(ml_kem_512_avx2 
PUBLIC -DKYBER_K=2) - set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) +if(OQS_ENABLE_KEM_ml_kem_512_x86_64) + add_library(ml_kem_512_x86_64 OBJECT mlkem-native_ml-kem-512_x86_64/cbd.c mlkem-native_ml-kem-512_x86_64/debug/debug.c mlkem-native_ml-kem-512_x86_64/indcpa.c mlkem-native_ml-kem-512_x86_64/kem.c mlkem-native_ml-kem-512_x86_64/ntt.c mlkem-native_ml-kem-512_x86_64/poly.c mlkem-native_ml-kem-512_x86_64/polyvec.c mlkem-native_ml-kem-512_x86_64/rej_uniform.c mlkem-native_ml-kem-512_x86_64/verify.c mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.c mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.S mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.c mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.S mlkem-native_ml-kem-512_x86_64/x86_64/src/intt.S mlkem-native_ml-kem-512_x86_64/x86_64/src/ntt.S mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_avx2.c mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_table.c mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.S mlkem-native_ml-kem-512_x86_64/zetas.c) + target_include_directories(ml_kem_512_x86_64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-512_x86_64) + target_include_directories(ml_kem_512_x86_64 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) + target_compile_options(ml_kem_512_x86_64 PRIVATE -mavx2 -mbmi2 -mpopcnt ) + target_compile_options(ml_kem_512_x86_64 PUBLIC -DMLKEM_K=2 -DFORCE_X86_64 -DMLKEM_NATIVE_ARITH_BACKEND_NAME=X86_64_DEFAULT -DMLKEM_USE_NATIVE -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT) + set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) +endif() + +if(OQS_ENABLE_KEM_ml_kem_512_aarch64) + add_library(ml_kem_512_aarch64 OBJECT mlkem-native_ml-kem-512_aarch64/aarch64/src/aarch64_zetas.c mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_clean.S mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_opt.S mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_clean.S mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_opt.S mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_clean.S 
mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_opt.S mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_clean.S mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_opt.S mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_asm_clean.S mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_table.c mlkem-native_ml-kem-512_aarch64/cbd.c mlkem-native_ml-kem-512_aarch64/debug/debug.c mlkem-native_ml-kem-512_aarch64/indcpa.c mlkem-native_ml-kem-512_aarch64/kem.c mlkem-native_ml-kem-512_aarch64/ntt.c mlkem-native_ml-kem-512_aarch64/poly.c mlkem-native_ml-kem-512_aarch64/polyvec.c mlkem-native_ml-kem-512_aarch64/rej_uniform.c mlkem-native_ml-kem-512_aarch64/verify.c mlkem-native_ml-kem-512_aarch64/zetas.c) + target_include_directories(ml_kem_512_aarch64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-512_aarch64) + target_include_directories(ml_kem_512_aarch64 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) + target_compile_options(ml_kem_512_aarch64 PUBLIC -DMLKEM_K=2 -DFORCE_AARCH64 -DMLKEM_NATIVE_ARITH_BACKEND_NAME=AARCH64_OPT -DMLKEM_USE_NATIVE -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT) + set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) endif() if(OQS_ENABLE_KEM_ml_kem_768) - add_library(ml_kem_768_ref OBJECT kem_ml_kem_768.c pqcrystals-kyber-standard_ml-kem-768_ref/cbd.c pqcrystals-kyber-standard_ml-kem-768_ref/indcpa.c pqcrystals-kyber-standard_ml-kem-768_ref/kem.c pqcrystals-kyber-standard_ml-kem-768_ref/ntt.c pqcrystals-kyber-standard_ml-kem-768_ref/poly.c pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.c pqcrystals-kyber-standard_ml-kem-768_ref/reduce.c pqcrystals-kyber-standard_ml-kem-768_ref/symmetric-shake.c pqcrystals-kyber-standard_ml-kem-768_ref/verify.c) - target_compile_options(ml_kem_768_ref PUBLIC -DKYBER_K=3) - target_include_directories(ml_kem_768_ref PRIVATE ${CMAKE_CURRENT_LIST_DIR}/pqcrystals-kyber-standard_ml-kem-768_ref) + add_library(ml_kem_768_ref OBJECT kem_ml_kem_768.c mlkem-native_ml-kem-768_ref/cbd.c 
mlkem-native_ml-kem-768_ref/debug/debug.c mlkem-native_ml-kem-768_ref/indcpa.c mlkem-native_ml-kem-768_ref/kem.c mlkem-native_ml-kem-768_ref/ntt.c mlkem-native_ml-kem-768_ref/poly.c mlkem-native_ml-kem-768_ref/polyvec.c mlkem-native_ml-kem-768_ref/rej_uniform.c mlkem-native_ml-kem-768_ref/verify.c mlkem-native_ml-kem-768_ref/zetas.c) + target_compile_options(ml_kem_768_ref PUBLIC -DMLKEM_K=3 -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM768_C) + target_include_directories(ml_kem_768_ref PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-768_ref) target_include_directories(ml_kem_768_ref PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) - target_compile_options(ml_kem_768_ref PUBLIC -DKYBER_K=3) + target_compile_options(ml_kem_768_ref PUBLIC -DMLKEM_K=3 -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM768_C) set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) endif() -if(OQS_ENABLE_KEM_ml_kem_768_avx2) - add_library(ml_kem_768_avx2 OBJECT pqcrystals-kyber-standard_ml-kem-768_avx2/basemul.S pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.c pqcrystals-kyber-standard_ml-kem-768_avx2/consts.c pqcrystals-kyber-standard_ml-kem-768_avx2/fq.S pqcrystals-kyber-standard_ml-kem-768_avx2/indcpa.c pqcrystals-kyber-standard_ml-kem-768_avx2/invntt.S pqcrystals-kyber-standard_ml-kem-768_avx2/kem.c pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.S pqcrystals-kyber-standard_ml-kem-768_avx2/poly.c pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.c pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.c pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.S pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric-shake.c pqcrystals-kyber-standard_ml-kem-768_avx2/verify.c) - target_include_directories(ml_kem_768_avx2 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/pqcrystals-kyber-standard_ml-kem-768_avx2) - target_include_directories(ml_kem_768_avx2 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) - target_compile_options(ml_kem_768_avx2 PRIVATE -mavx2 -mbmi2 -mpopcnt ) - target_compile_options(ml_kem_768_avx2 
PUBLIC -DKYBER_K=3) - set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) +if(OQS_ENABLE_KEM_ml_kem_768_x86_64) + add_library(ml_kem_768_x86_64 OBJECT mlkem-native_ml-kem-768_x86_64/cbd.c mlkem-native_ml-kem-768_x86_64/debug/debug.c mlkem-native_ml-kem-768_x86_64/indcpa.c mlkem-native_ml-kem-768_x86_64/kem.c mlkem-native_ml-kem-768_x86_64/ntt.c mlkem-native_ml-kem-768_x86_64/poly.c mlkem-native_ml-kem-768_x86_64/polyvec.c mlkem-native_ml-kem-768_x86_64/rej_uniform.c mlkem-native_ml-kem-768_x86_64/verify.c mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.c mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.S mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.c mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.S mlkem-native_ml-kem-768_x86_64/x86_64/src/intt.S mlkem-native_ml-kem-768_x86_64/x86_64/src/ntt.S mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_avx2.c mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_table.c mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.S mlkem-native_ml-kem-768_x86_64/zetas.c) + target_include_directories(ml_kem_768_x86_64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-768_x86_64) + target_include_directories(ml_kem_768_x86_64 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) + target_compile_options(ml_kem_768_x86_64 PRIVATE -mavx2 -mbmi2 -mpopcnt ) + target_compile_options(ml_kem_768_x86_64 PUBLIC -DMLKEM_K=3 -DFORCE_X86_64 -DMLKEM_NATIVE_ARITH_BACKEND_NAME=X86_64_DEFAULT -DMLKEM_USE_NATIVE -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT) + set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) +endif() + +if(OQS_ENABLE_KEM_ml_kem_768_aarch64) + add_library(ml_kem_768_aarch64 OBJECT mlkem-native_ml-kem-768_aarch64/aarch64/src/aarch64_zetas.c mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_clean.S mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_opt.S mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_clean.S mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_opt.S mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_clean.S 
mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_opt.S mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_clean.S mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_opt.S mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_asm_clean.S mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_table.c mlkem-native_ml-kem-768_aarch64/cbd.c mlkem-native_ml-kem-768_aarch64/debug/debug.c mlkem-native_ml-kem-768_aarch64/indcpa.c mlkem-native_ml-kem-768_aarch64/kem.c mlkem-native_ml-kem-768_aarch64/ntt.c mlkem-native_ml-kem-768_aarch64/poly.c mlkem-native_ml-kem-768_aarch64/polyvec.c mlkem-native_ml-kem-768_aarch64/rej_uniform.c mlkem-native_ml-kem-768_aarch64/verify.c mlkem-native_ml-kem-768_aarch64/zetas.c) + target_include_directories(ml_kem_768_aarch64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-768_aarch64) + target_include_directories(ml_kem_768_aarch64 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) + target_compile_options(ml_kem_768_aarch64 PUBLIC -DMLKEM_K=3 -DFORCE_AARCH64 -DMLKEM_NATIVE_ARITH_BACKEND_NAME=AARCH64_OPT -DMLKEM_USE_NATIVE -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT) + set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) endif() if(OQS_ENABLE_KEM_ml_kem_1024) - add_library(ml_kem_1024_ref OBJECT kem_ml_kem_1024.c pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.c pqcrystals-kyber-standard_ml-kem-1024_ref/indcpa.c pqcrystals-kyber-standard_ml-kem-1024_ref/kem.c pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.c pqcrystals-kyber-standard_ml-kem-1024_ref/poly.c pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.c pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.c pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric-shake.c pqcrystals-kyber-standard_ml-kem-1024_ref/verify.c) - target_compile_options(ml_kem_1024_ref PUBLIC -DKYBER_K=4) - target_include_directories(ml_kem_1024_ref PRIVATE ${CMAKE_CURRENT_LIST_DIR}/pqcrystals-kyber-standard_ml-kem-1024_ref) + add_library(ml_kem_1024_ref OBJECT kem_ml_kem_1024.c 
mlkem-native_ml-kem-1024_ref/cbd.c mlkem-native_ml-kem-1024_ref/debug/debug.c mlkem-native_ml-kem-1024_ref/indcpa.c mlkem-native_ml-kem-1024_ref/kem.c mlkem-native_ml-kem-1024_ref/ntt.c mlkem-native_ml-kem-1024_ref/poly.c mlkem-native_ml-kem-1024_ref/polyvec.c mlkem-native_ml-kem-1024_ref/rej_uniform.c mlkem-native_ml-kem-1024_ref/verify.c mlkem-native_ml-kem-1024_ref/zetas.c) + target_compile_options(ml_kem_1024_ref PUBLIC -DMLKEM_K=4 -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM1024_C) + target_include_directories(ml_kem_1024_ref PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-1024_ref) target_include_directories(ml_kem_1024_ref PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) - target_compile_options(ml_kem_1024_ref PUBLIC -DKYBER_K=4) + target_compile_options(ml_kem_1024_ref PUBLIC -DMLKEM_K=4 -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM1024_C) set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) endif() -if(OQS_ENABLE_KEM_ml_kem_1024_avx2) - add_library(ml_kem_1024_avx2 OBJECT pqcrystals-kyber-standard_ml-kem-1024_avx2/basemul.S pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.c pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.c pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.S pqcrystals-kyber-standard_ml-kem-1024_avx2/indcpa.c pqcrystals-kyber-standard_ml-kem-1024_avx2/invntt.S pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.c pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.S pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.c pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.c pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.c pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.S pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric-shake.c pqcrystals-kyber-standard_ml-kem-1024_avx2/verify.c) - target_include_directories(ml_kem_1024_avx2 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/pqcrystals-kyber-standard_ml-kem-1024_avx2) - target_include_directories(ml_kem_1024_avx2 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) - target_compile_options(ml_kem_1024_avx2 
PRIVATE -mavx2 -mbmi2 -mpopcnt ) - target_compile_options(ml_kem_1024_avx2 PUBLIC -DKYBER_K=4) - set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) +if(OQS_ENABLE_KEM_ml_kem_1024_x86_64) + add_library(ml_kem_1024_x86_64 OBJECT mlkem-native_ml-kem-1024_x86_64/cbd.c mlkem-native_ml-kem-1024_x86_64/debug/debug.c mlkem-native_ml-kem-1024_x86_64/indcpa.c mlkem-native_ml-kem-1024_x86_64/kem.c mlkem-native_ml-kem-1024_x86_64/ntt.c mlkem-native_ml-kem-1024_x86_64/poly.c mlkem-native_ml-kem-1024_x86_64/polyvec.c mlkem-native_ml-kem-1024_x86_64/rej_uniform.c mlkem-native_ml-kem-1024_x86_64/verify.c mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.c mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.S mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.c mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.S mlkem-native_ml-kem-1024_x86_64/x86_64/src/intt.S mlkem-native_ml-kem-1024_x86_64/x86_64/src/ntt.S mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_avx2.c mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_table.c mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.S mlkem-native_ml-kem-1024_x86_64/zetas.c) + target_include_directories(ml_kem_1024_x86_64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-1024_x86_64) + target_include_directories(ml_kem_1024_x86_64 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) + target_compile_options(ml_kem_1024_x86_64 PRIVATE -mavx2 -mbmi2 -mpopcnt ) + target_compile_options(ml_kem_1024_x86_64 PUBLIC -DMLKEM_K=4 -DFORCE_X86_64 -DMLKEM_NATIVE_ARITH_BACKEND_NAME=X86_64_DEFAULT -DMLKEM_USE_NATIVE -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT) + set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) +endif() + +if(OQS_ENABLE_KEM_ml_kem_1024_aarch64) + add_library(ml_kem_1024_aarch64 OBJECT mlkem-native_ml-kem-1024_aarch64/aarch64/src/aarch64_zetas.c mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_clean.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_opt.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_clean.S 
mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_opt.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_clean.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_opt.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_clean.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_opt.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_asm_clean.S mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_table.c mlkem-native_ml-kem-1024_aarch64/cbd.c mlkem-native_ml-kem-1024_aarch64/debug/debug.c mlkem-native_ml-kem-1024_aarch64/indcpa.c mlkem-native_ml-kem-1024_aarch64/kem.c mlkem-native_ml-kem-1024_aarch64/ntt.c mlkem-native_ml-kem-1024_aarch64/poly.c mlkem-native_ml-kem-1024_aarch64/polyvec.c mlkem-native_ml-kem-1024_aarch64/rej_uniform.c mlkem-native_ml-kem-1024_aarch64/verify.c mlkem-native_ml-kem-1024_aarch64/zetas.c) + target_include_directories(ml_kem_1024_aarch64 PRIVATE ${CMAKE_CURRENT_LIST_DIR}/mlkem-native_ml-kem-1024_aarch64) + target_include_directories(ml_kem_1024_aarch64 PRIVATE ${PROJECT_SOURCE_DIR}/src/common/pqclean_shims) + target_compile_options(ml_kem_1024_aarch64 PUBLIC -DMLKEM_K=4 -DFORCE_AARCH64 -DMLKEM_NATIVE_ARITH_BACKEND_NAME=AARCH64_OPT -DMLKEM_USE_NATIVE -DMLKEM_NAMESPACE_PREFIX=PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT) + set(_ML_KEM_OBJS ${_ML_KEM_OBJS} $) endif() set(ML_KEM_OBJS ${_ML_KEM_OBJS} PARENT_SCOPE) diff --git a/src/kem/ml_kem/kem_ml_kem_1024.c b/src/kem/ml_kem/kem_ml_kem_1024.c index bc533aef9e..21f746f963 100644 --- a/src/kem/ml_kem/kem_ml_kem_1024.c +++ b/src/kem/ml_kem/kem_ml_kem_1024.c @@ -30,61 +30,97 @@ OQS_KEM *OQS_KEM_ml_kem_1024_new(void) { return kem; } -extern int pqcrystals_ml_kem_1024_ref_keypair(uint8_t *pk, uint8_t *sk); -extern int pqcrystals_ml_kem_1024_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -extern int pqcrystals_ml_kem_1024_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#if defined(OQS_ENABLE_KEM_ml_kem_1024_avx2) -extern int pqcrystals_ml_kem_1024_avx2_keypair(uint8_t 
*pk, uint8_t *sk); -extern int pqcrystals_ml_kem_1024_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -extern int pqcrystals_ml_kem_1024_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_C_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_C_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_C_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); + +#if defined(OQS_ENABLE_KEM_ml_kem_1024_x86_64) +extern int PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); +#endif + +#if defined(OQS_ENABLE_KEM_ml_kem_1024_aarch64) +extern int PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); #endif OQS_API OQS_STATUS OQS_KEM_ml_kem_1024_keypair(uint8_t *public_key, uint8_t *secret_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_1024_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_1024_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_1024_avx2_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT_keypair(public_key, secret_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_1024_ref_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_keypair(public_key, secret_key); + } +#endif /* OQS_DIST_BUILD */ +#elif 
defined(OQS_ENABLE_KEM_ml_kem_1024_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT_keypair(public_key, secret_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_keypair(public_key, secret_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_1024_ref_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_keypair(public_key, secret_key); #endif } OQS_API OQS_STATUS OQS_KEM_ml_kem_1024_encaps(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_1024_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_1024_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_1024_avx2_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT_enc(ciphertext, shared_secret, public_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_1024_ref_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_enc(ciphertext, shared_secret, public_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_1024_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT_enc(ciphertext, shared_secret, public_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_enc(ciphertext, shared_secret, public_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_1024_ref_enc(ciphertext, shared_secret, public_key); + return 
(OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_enc(ciphertext, shared_secret, public_key); #endif } OQS_API OQS_STATUS OQS_KEM_ml_kem_1024_decaps(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_1024_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_1024_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_1024_avx2_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_X86_64_DEFAULT_dec(shared_secret, ciphertext, secret_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_dec(shared_secret, ciphertext, secret_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_1024_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_AARCH64_OPT_dec(shared_secret, ciphertext, secret_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_1024_ref_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_dec(shared_secret, ciphertext, secret_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_1024_ref_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM1024_C_dec(shared_secret, ciphertext, secret_key); #endif } diff --git a/src/kem/ml_kem/kem_ml_kem_512.c b/src/kem/ml_kem/kem_ml_kem_512.c index f2dcde53d2..b51a6b2afd 100644 --- a/src/kem/ml_kem/kem_ml_kem_512.c +++ b/src/kem/ml_kem/kem_ml_kem_512.c @@ -30,61 +30,97 @@ OQS_KEM *OQS_KEM_ml_kem_512_new(void) { return kem; } -extern int pqcrystals_ml_kem_512_ref_keypair(uint8_t *pk, uint8_t *sk); -extern int pqcrystals_ml_kem_512_ref_enc(uint8_t *ct, 
uint8_t *ss, const uint8_t *pk); -extern int pqcrystals_ml_kem_512_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#if defined(OQS_ENABLE_KEM_ml_kem_512_avx2) -extern int pqcrystals_ml_kem_512_avx2_keypair(uint8_t *pk, uint8_t *sk); -extern int pqcrystals_ml_kem_512_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -extern int pqcrystals_ml_kem_512_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_C_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_C_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_C_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); + +#if defined(OQS_ENABLE_KEM_ml_kem_512_x86_64) +extern int PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); +#endif + +#if defined(OQS_ENABLE_KEM_ml_kem_512_aarch64) +extern int PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); #endif OQS_API OQS_STATUS OQS_KEM_ml_kem_512_keypair(uint8_t *public_key, uint8_t *secret_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_512_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_512_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_512_avx2_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT_keypair(public_key, secret_key); #if defined(OQS_DIST_BUILD) } else { - 
return (OQS_STATUS) pqcrystals_ml_kem_512_ref_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_keypair(public_key, secret_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_512_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_keypair(public_key, secret_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_keypair(public_key, secret_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_512_ref_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_keypair(public_key, secret_key); #endif } OQS_API OQS_STATUS OQS_KEM_ml_kem_512_encaps(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_512_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_512_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_512_avx2_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT_enc(ciphertext, shared_secret, public_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_512_ref_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_enc(ciphertext, shared_secret, public_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_512_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_enc(ciphertext, shared_secret, public_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) 
PQCP_MLKEM_NATIVE_MLKEM512_C_enc(ciphertext, shared_secret, public_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_512_ref_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_enc(ciphertext, shared_secret, public_key); #endif } OQS_API OQS_STATUS OQS_KEM_ml_kem_512_decaps(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_512_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_512_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_512_avx2_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_X86_64_DEFAULT_dec(shared_secret, ciphertext, secret_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_dec(shared_secret, ciphertext, secret_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_512_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_dec(shared_secret, ciphertext, secret_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_512_ref_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_dec(shared_secret, ciphertext, secret_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_512_ref_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM512_C_dec(shared_secret, ciphertext, secret_key); #endif } diff --git a/src/kem/ml_kem/kem_ml_kem_768.c b/src/kem/ml_kem/kem_ml_kem_768.c index 14eb6ba404..f902877a9c 100644 --- a/src/kem/ml_kem/kem_ml_kem_768.c +++ b/src/kem/ml_kem/kem_ml_kem_768.c @@ 
-30,61 +30,97 @@ OQS_KEM *OQS_KEM_ml_kem_768_new(void) { return kem; } -extern int pqcrystals_ml_kem_768_ref_keypair(uint8_t *pk, uint8_t *sk); -extern int pqcrystals_ml_kem_768_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -extern int pqcrystals_ml_kem_768_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#if defined(OQS_ENABLE_KEM_ml_kem_768_avx2) -extern int pqcrystals_ml_kem_768_avx2_keypair(uint8_t *pk, uint8_t *sk); -extern int pqcrystals_ml_kem_768_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -extern int pqcrystals_ml_kem_768_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_C_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_C_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_C_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); + +#if defined(OQS_ENABLE_KEM_ml_kem_768_x86_64) +extern int PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); +#endif + +#if defined(OQS_ENABLE_KEM_ml_kem_768_aarch64) +extern int PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT_keypair(uint8_t *pk, uint8_t *sk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); +extern int PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); #endif OQS_API OQS_STATUS OQS_KEM_ml_kem_768_keypair(uint8_t *public_key, uint8_t *secret_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_768_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_768_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return 
(OQS_STATUS) pqcrystals_ml_kem_768_avx2_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT_keypair(public_key, secret_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_768_ref_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_keypair(public_key, secret_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_768_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT_keypair(public_key, secret_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_keypair(public_key, secret_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_768_ref_keypair(public_key, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_keypair(public_key, secret_key); #endif } OQS_API OQS_STATUS OQS_KEM_ml_kem_768_encaps(uint8_t *ciphertext, uint8_t *shared_secret, const uint8_t *public_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_768_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_768_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_768_avx2_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT_enc(ciphertext, shared_secret, public_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_768_ref_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_enc(ciphertext, shared_secret, public_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_768_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* 
OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT_enc(ciphertext, shared_secret, public_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_enc(ciphertext, shared_secret, public_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_768_ref_enc(ciphertext, shared_secret, public_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_enc(ciphertext, shared_secret, public_key); #endif } OQS_API OQS_STATUS OQS_KEM_ml_kem_768_decaps(uint8_t *shared_secret, const uint8_t *ciphertext, const uint8_t *secret_key) { -#if defined(OQS_ENABLE_KEM_ml_kem_768_avx2) +#if defined(OQS_ENABLE_KEM_ml_kem_768_x86_64) #if defined(OQS_DIST_BUILD) if (OQS_CPU_has_extension(OQS_CPU_EXT_AVX2) && OQS_CPU_has_extension(OQS_CPU_EXT_BMI2) && OQS_CPU_has_extension(OQS_CPU_EXT_POPCNT)) { #endif /* OQS_DIST_BUILD */ - return (OQS_STATUS) pqcrystals_ml_kem_768_avx2_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_X86_64_DEFAULT_dec(shared_secret, ciphertext, secret_key); +#if defined(OQS_DIST_BUILD) + } else { + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_dec(shared_secret, ciphertext, secret_key); + } +#endif /* OQS_DIST_BUILD */ +#elif defined(OQS_ENABLE_KEM_ml_kem_768_aarch64) +#if defined(OQS_DIST_BUILD) + if (OQS_CPU_has_extension(OQS_CPU_EXT_ARM_NEON)) { +#endif /* OQS_DIST_BUILD */ + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_AARCH64_OPT_dec(shared_secret, ciphertext, secret_key); #if defined(OQS_DIST_BUILD) } else { - return (OQS_STATUS) pqcrystals_ml_kem_768_ref_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_dec(shared_secret, ciphertext, secret_key); } #endif /* OQS_DIST_BUILD */ #else - return (OQS_STATUS) pqcrystals_ml_kem_768_ref_dec(shared_secret, ciphertext, secret_key); + return (OQS_STATUS) PQCP_MLKEM_NATIVE_MLKEM768_C_dec(shared_secret, ciphertext, secret_key); #endif } diff --git 
a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/LICENSE similarity index 100% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/LICENSE rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/LICENSE diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/README.md b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/README.md new file mode 100644 index 0000000000..e499a4a229 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/README.md @@ -0,0 +1,19 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +# AArch64 backend (little endian) + +This directory contains a native backend for little endian AArch64 systems. It is derived from the following research +works: + +- _Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1_, Hanno Becker, Vincent Hwang, Matthias + J. Kannwischer, Bo-Yin Yang, and Shang-Yi Yang, [https://eprint.iacr.org/2021/986](https://eprint.iacr.org/2021/986) +- _Fast and Clean: Auditable high-performance assembly via constraint solving_, Amin Abdulrahman, Hanno Becker, Matthias + J. Kannwischer, Fabien Klein, [https://eprint.iacr.org/2022/1303](https://eprint.iacr.org/2022/1303) + +## Profiles + +This backend comes with two profiles: "clean" and optimized. The "clean" backend is handwritten and meant to be easy to +read and modify; for example, it heavily leverages register aliases and assembly macros. The optimized profile is +automatically generated from the clean profile via [SLOTHY](https://github.com/slothy-optimizer/slothy). Currently, the +target architecture is Cortex-A55, but you can easily re-optimize the code for a different microarchitecture supported +by SLOTHY, by adjusting the parameters in [optimize.sh](src/optimize.sh). 
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/clean.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/clean.h new file mode 100644 index 0000000000..43a401dfc4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/clean.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME AARCH64_CLEAN + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. */ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "aarch64/src/clean_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/opt.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/opt.h new file mode 100644 index 0000000000..04323c3e79 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/opt.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for optimized assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. 
*/ +#define MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME AARCH64_OPT + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. */ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "aarch64/src/opt_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/aarch64_zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/aarch64_zetas.c new file mode 100644 index 0000000000..1e189fd995 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/aarch64_zetas.c @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +#include +#include "arith_native_aarch64.h" + +/* + * Table of zeta values used in the AArch64 forward NTT + * See autogen for details. 
+ */ +ALIGN const int16_t aarch64_ntt_zetas_layer01234[] = { + -1600, -15749, -749, -7373, -40, -394, -687, -6762, 630, 6201, + -1432, -14095, 848, 8347, 0, 0, 1062, 10453, 296, 2914, + -882, -8682, 0, 0, -1410, -13879, 1339, 13180, 1476, 14529, + 0, 0, 193, 1900, -283, -2786, 56, 551, 0, 0, + 797, 7845, -1089, -10719, 1333, 13121, 0, 0, -543, -5345, + 1426, 14036, -1235, -12156, 0, 0, -69, -679, 535, 5266, + -447, -4400, 0, 0, 569, 5601, -936, -9213, -450, -4429, + 0, 0, -1583, -15582, -1355, -13338, 821, 8081, 0, 0, +}; + +ALIGN const int16_t aarch64_ntt_zetas_layer56[] = { + 289, 289, 331, 331, -76, -76, -1573, -1573, 2845, + 2845, 3258, 3258, -748, -748, -15483, -15483, 17, 17, + 583, 583, 1637, 1637, -1041, -1041, 167, 167, 5739, + 5739, 16113, 16113, -10247, -10247, -568, -568, -680, -680, + 723, 723, 1100, 1100, -5591, -5591, -6693, -6693, 7117, + 7117, 10828, 10828, 1197, 1197, -1025, -1025, -1052, -1052, + -1274, -1274, 11782, 11782, -10089, -10089, -10355, -10355, -12540, + -12540, 1409, 1409, -48, -48, 756, 756, -314, -314, + 13869, 13869, -472, -472, 7441, 7441, -3091, -3091, -667, + -667, 233, 233, -1173, -1173, -279, -279, -6565, -6565, + 2293, 2293, -11546, -11546, -2746, -2746, 650, 650, -1352, + -1352, -816, -816, 632, 632, 6398, 6398, -13308, -13308, + -8032, -8032, 6221, 6221, -1626, -1626, -540, -540, -1482, + -1482, 1461, 1461, -16005, -16005, -5315, -5315, -14588, -14588, + 14381, 14381, 1651, 1651, -1540, -1540, 952, 952, -642, + -642, 16251, 16251, -15159, -15159, 9371, 9371, -6319, -6319, + -464, -464, 33, 33, 1320, 1320, -1414, -1414, -4567, + -4567, 325, 325, 12993, 12993, -13918, -13918, 939, 939, + -892, -892, 733, 733, 268, 268, 9243, 9243, -8780, + -8780, 7215, 7215, 2638, 2638, -1021, -1021, -941, -941, + -992, -992, 641, 641, -10050, -10050, -9262, -9262, -9764, + -9764, 6309, 6309, -1010, -1010, 1435, 1435, 807, 807, + 452, 452, -9942, -9942, 14125, 14125, 7943, 7943, 4449, + 4449, 1584, 1584, -1292, -1292, 375, 375, -1239, -1239, 
+ 15592, 15592, -12717, -12717, 3691, 3691, -12196, -12196, -1031, + -1031, -109, -109, -780, -780, 1645, 1645, -10148, -10148, + -1073, -1073, -7678, -7678, 16192, 16192, 1438, 1438, -461, + -461, 1534, 1534, -927, -927, 14155, 14155, -4538, -4538, + 15099, 15099, -9125, -9125, 1063, 1063, -556, -556, -1230, + -1230, -863, -863, 10463, 10463, -5473, -5473, -12107, -12107, + -8495, -8495, 319, 319, 757, 757, 561, 561, -735, + -735, 3140, 3140, 7451, 7451, 5522, 5522, -7235, -7235, + -682, -682, -712, -712, 1481, 1481, 648, 648, -6713, + -6713, -7008, -7008, 14578, 14578, 6378, 6378, -525, -525, + 403, 403, 1143, 1143, -554, -554, -5168, -5168, 3967, + 3967, 11251, 11251, -5453, -5453, 1092, 1092, 1026, 1026, + -1179, -1179, 886, 886, 10749, 10749, 10099, 10099, -11605, + -11605, 8721, 8721, -855, -855, -219, -219, 1227, 1227, + 910, 910, -8416, -8416, -2156, -2156, 12078, 12078, 8957, + 8957, -1607, -1607, -1455, -1455, -1219, -1219, 885, 885, + -15818, -15818, -14322, -14322, -11999, -11999, 8711, 8711, 1212, + 1212, 1029, 1029, -394, -394, -1175, -1175, 11930, 11930, + 10129, 10129, -3878, -3878, -11566, -11566, +}; + +ALIGN const int16_t aarch64_invntt_zetas_layer01234[] = { + 1583, 15582, -821, -8081, 1355, 13338, 0, 0, -569, -5601, + 450, 4429, 936, 9213, 0, 0, 69, 679, 447, 4400, + -535, -5266, 0, 0, 543, 5345, 1235, 12156, -1426, -14036, + 0, 0, -797, -7845, -1333, -13121, 1089, 10719, 0, 0, + -193, -1900, -56, -551, 283, 2786, 0, 0, 1410, 13879, + -1476, -14529, -1339, -13180, 0, 0, -1062, -10453, 882, 8682, + -296, -2914, 0, 0, 1600, 15749, 40, 394, 749, 7373, + -848, -8347, 1432, 14095, -630, -6201, 687, 6762, 0, 0, +}; + +ALIGN const int16_t aarch64_invntt_zetas_layer56[] = { + -910, -910, -1227, -1227, 219, 219, 855, 855, -8957, + -8957, -12078, -12078, 2156, 2156, 8416, 8416, 1175, 1175, + 394, 394, -1029, -1029, -1212, -1212, 11566, 11566, 3878, + 3878, -10129, -10129, -11930, -11930, -885, -885, 1219, 1219, + 1455, 1455, 1607, 1607, -8711, -8711, 
11999, 11999, 14322, + 14322, 15818, 15818, -648, -648, -1481, -1481, 712, 712, + 682, 682, -6378, -6378, -14578, -14578, 7008, 7008, 6713, + 6713, -886, -886, 1179, 1179, -1026, -1026, -1092, -1092, + -8721, -8721, 11605, 11605, -10099, -10099, -10749, -10749, 554, + 554, -1143, -1143, -403, -403, 525, 525, 5453, 5453, + -11251, -11251, -3967, -3967, 5168, 5168, 927, 927, -1534, + -1534, 461, 461, -1438, -1438, 9125, 9125, -15099, -15099, + 4538, 4538, -14155, -14155, 735, 735, -561, -561, -757, + -757, -319, -319, 7235, 7235, -5522, -5522, -7451, -7451, + -3140, -3140, 863, 863, 1230, 1230, 556, 556, -1063, + -1063, 8495, 8495, 12107, 12107, 5473, 5473, -10463, -10463, + -452, -452, -807, -807, -1435, -1435, 1010, 1010, -4449, + -4449, -7943, -7943, -14125, -14125, 9942, 9942, -1645, -1645, + 780, 780, 109, 109, 1031, 1031, -16192, -16192, 7678, + 7678, 1073, 1073, 10148, 10148, 1239, 1239, -375, -375, + 1292, 1292, -1584, -1584, 12196, 12196, -3691, -3691, 12717, + 12717, -15592, -15592, 1414, 1414, -1320, -1320, -33, -33, + 464, 464, 13918, 13918, -12993, -12993, -325, -325, 4567, + 4567, -641, -641, 992, 992, 941, 941, 1021, 1021, + -6309, -6309, 9764, 9764, 9262, 9262, 10050, 10050, -268, + -268, -733, -733, 892, 892, -939, -939, -2638, -2638, + -7215, -7215, 8780, 8780, -9243, -9243, -632, -632, 816, + 816, 1352, 1352, -650, -650, -6221, -6221, 8032, 8032, + 13308, 13308, -6398, -6398, 642, 642, -952, -952, 1540, + 1540, -1651, -1651, 6319, 6319, -9371, -9371, 15159, 15159, + -16251, -16251, -1461, -1461, 1482, 1482, 540, 540, 1626, + 1626, -14381, -14381, 14588, 14588, 5315, 5315, 16005, 16005, + 1274, 1274, 1052, 1052, 1025, 1025, -1197, -1197, 12540, + 12540, 10355, 10355, 10089, 10089, -11782, -11782, 279, 279, + 1173, 1173, -233, -233, 667, 667, 2746, 2746, 11546, + 11546, -2293, -2293, 6565, 6565, 314, 314, -756, -756, + 48, 48, -1409, -1409, 3091, 3091, -7441, -7441, 472, + 472, -13869, -13869, 1573, 1573, 76, 76, -331, -331, + -289, -289, 15483, 
15483, 748, 748, -3258, -3258, -2845, + -2845, -1100, -1100, -723, -723, 680, 680, 568, 568, + -10828, -10828, -7117, -7117, 6693, 6693, 5591, 5591, 1041, + 1041, -1637, -1637, -583, -583, -17, -17, 10247, 10247, + -16113, -16113, -5739, -5739, -167, -167, +}; + +ALIGN const int16_t aarch64_zetas_mulcache_native[] = { + 17, -17, -568, 568, 583, -583, -680, 680, 1637, -1637, 723, + -723, -1041, 1041, 1100, -1100, 1409, -1409, -667, 667, -48, 48, + 233, -233, 756, -756, -1173, 1173, -314, 314, -279, 279, -1626, + 1626, 1651, -1651, -540, 540, -1540, 1540, -1482, 1482, 952, -952, + 1461, -1461, -642, 642, 939, -939, -1021, 1021, -892, 892, -941, + 941, 733, -733, -992, 992, 268, -268, 641, -641, 1584, -1584, + -1031, 1031, -1292, 1292, -109, 109, 375, -375, -780, 780, -1239, + 1239, 1645, -1645, 1063, -1063, 319, -319, -556, 556, 757, -757, + -1230, 1230, 561, -561, -863, 863, -735, 735, -525, 525, 1092, + -1092, 403, -403, 1026, -1026, 1143, -1143, -1179, 1179, -554, 554, + 886, -886, -1607, 1607, 1212, -1212, -1455, 1455, 1029, -1029, -1219, + 1219, -394, 394, 885, -885, -1175, 1175, +}; + +ALIGN const int16_t aarch64_zetas_mulcache_twisted_native[] = { + 167, -167, -5591, 5591, 5739, -5739, -6693, 6693, 16113, + -16113, 7117, -7117, -10247, 10247, 10828, -10828, 13869, -13869, + -6565, 6565, -472, 472, 2293, -2293, 7441, -7441, -11546, + 11546, -3091, 3091, -2746, 2746, -16005, 16005, 16251, -16251, + -5315, 5315, -15159, 15159, -14588, 14588, 9371, -9371, 14381, + -14381, -6319, 6319, 9243, -9243, -10050, 10050, -8780, 8780, + -9262, 9262, 7215, -7215, -9764, 9764, 2638, -2638, 6309, + -6309, 15592, -15592, -10148, 10148, -12717, 12717, -1073, 1073, + 3691, -3691, -7678, 7678, -12196, 12196, 16192, -16192, 10463, + -10463, 3140, -3140, -5473, 5473, 7451, -7451, -12107, 12107, + 5522, -5522, -8495, 8495, -7235, 7235, -5168, 5168, 10749, + -10749, 3967, -3967, 10099, -10099, 11251, -11251, -11605, 11605, + -5453, 5453, 8721, -8721, -15818, 15818, 11930, -11930, 
-14322, + 14322, 10129, -10129, -11999, 11999, -3878, 3878, 8711, -8711, + -11566, 11566, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_aarch64_zetas MLKEM_NAMESPACE(empty_cu_aarch64_zetas) +int empty_cu_aarch64_zetas; +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/arith_native_aarch64.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/arith_native_aarch64.h new file mode 100644 index 0000000000..6a5ee8a7d6 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/arith_native_aarch64.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_AARCH64_NATIVE_H +#define MLKEM_AARCH64_NATIVE_H + +#include +#include "common.h" + +#define aarch64_ntt_zetas_layer01234 \ + MLKEM_NAMESPACE(aarch64_ntt_zetas_layer01234) +#define aarch64_ntt_zetas_layer56 MLKEM_NAMESPACE(aarch64_ntt_zetas_layer56) +#define aarch64_invntt_zetas_layer01234 \ + MLKEM_NAMESPACE(aarch64_invntt_zetas_layer01234) +#define aarch64_invntt_zetas_layer56 \ + MLKEM_NAMESPACE(aarch64_invntt_zetas_layer56) +#define aarch64_zetas_mulcache_native \ + MLKEM_NAMESPACE(aarch64_zetas_mulcache_native) +#define aarch64_zetas_mulcache_twisted_native \ + MLKEM_NAMESPACE(aarch64_zetas_mulcache_twisted_native) +#define rej_uniform_table MLKEM_NAMESPACE(rej_uniform_table) + +extern const int16_t aarch64_ntt_zetas_layer01234[]; +extern const int16_t aarch64_ntt_zetas_layer56[]; +extern const int16_t aarch64_invntt_zetas_layer01234[]; +extern const int16_t aarch64_invntt_zetas_layer56[]; +extern const int16_t aarch64_zetas_mulcache_native[]; +extern const int16_t aarch64_zetas_mulcache_twisted_native[]; +extern const uint8_t rej_uniform_table[]; + +#define ntt_asm_clean MLKEM_NAMESPACE(ntt_asm_clean) +void ntt_asm_clean(int16_t *, const int16_t *, const int16_t *); + +#define ntt_asm_opt 
MLKEM_NAMESPACE(ntt_asm_opt) +void ntt_asm_opt(int16_t *, const int16_t *, const int16_t *); + +#define intt_asm_clean MLKEM_NAMESPACE(intt_asm_clean) +void intt_asm_clean(int16_t *, const int16_t *, const int16_t *); + +#define intt_asm_opt MLKEM_NAMESPACE(intt_asm_opt) +void intt_asm_opt(int16_t *, const int16_t *, const int16_t *); + +#define rej_uniform_asm_clean MLKEM_NAMESPACE(rej_uniform_asm_clean) +unsigned int rej_uniform_asm_clean(int16_t *r, const uint8_t *buf, + unsigned int buflen, const uint8_t *table); + +#define poly_reduce_asm_clean MLKEM_NAMESPACE(poly_reduce_asm_clean) +void poly_reduce_asm_clean(int16_t *); + +#define poly_reduce_asm_opt MLKEM_NAMESPACE(poly_reduce_asm_opt) +void poly_reduce_asm_opt(int16_t *); + +#define poly_tomont_asm_clean MLKEM_NAMESPACE(poly_tomont_asm_clean) +void poly_tomont_asm_clean(int16_t *); + +#define poly_tomont_asm_opt MLKEM_NAMESPACE(poly_tomont_asm_opt) +void poly_tomont_asm_opt(int16_t *); + +#define poly_mulcache_compute_asm_clean \ + MLKEM_NAMESPACE(poly_mulcache_compute_asm_clean) +void poly_mulcache_compute_asm_clean(int16_t *, const int16_t *, + const int16_t *, const int16_t *); + + +#define poly_mulcache_compute_asm_opt \ + MLKEM_NAMESPACE(poly_mulcache_compute_asm_opt) +void poly_mulcache_compute_asm_opt(int16_t *, const int16_t *, const int16_t *, + const int16_t *); + +#define poly_tobytes_asm_clean MLKEM_NAMESPACE(poly_tobytes_asm_clean) +void poly_tobytes_asm_clean(uint8_t *r, const int16_t *a); + +#define poly_tobytes_asm_opt MLKEM_NAMESPACE(poly_tobytes_asm_opt) +void poly_tobytes_asm_opt(uint8_t *r, const int16_t *a); + +#define polyvec_basemul_acc_montgomery_cached_asm_clean \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) +void polyvec_basemul_acc_montgomery_cached_asm_clean(int16_t *r, + const int16_t *a, + const int16_t *b, + const int16_t *b_cache); + +#define polyvec_basemul_acc_montgomery_cached_asm_opt \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) 
+void polyvec_basemul_acc_montgomery_cached_asm_opt(int16_t *r, const int16_t *a, + const int16_t *b, + const int16_t *b_cache); + +#endif /* MLKEM_AARCH64_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/clean_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/clean_impl.h new file mode 100644 index 0000000000..b0ff3d5972 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/clean_impl.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include "arith_native_aarch64.h" + +#include "poly.h" +#include "polyvec.h" + +/* Set of primitives that this backend replaces */ +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_REJ_UNIFORM + +static INLINE void ntt_native(poly *data) +{ + ntt_asm_clean(data->coeffs, aarch64_ntt_zetas_layer01234, + aarch64_ntt_zetas_layer56); +} + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +static INLINE void intt_native(poly *data) +{ + intt_asm_clean(data->coeffs, aarch64_invntt_zetas_layer01234, + aarch64_invntt_zetas_layer56); +} + +static INLINE void poly_reduce_native(poly *data) +{ + poly_reduce_asm_clean(data->coeffs); +} +static INLINE void poly_tomont_native(poly *data) +{ + poly_tomont_asm_clean(data->coeffs); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + poly_mulcache_compute_asm_clean(x->coeffs, y->coeffs, + 
aarch64_zetas_mulcache_native, + aarch64_zetas_mulcache_twisted_native); +} +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_asm_clean( + r->coeffs, a->vec[0].coeffs, b->vec[0].coeffs, b_cache->vec[0].coeffs); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + poly_tobytes_asm_clean(r, a->coeffs); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + if (len != MLKEM_N || buflen % 24 != 0) + { + return -1; + } + return (int)rej_uniform_asm_clean(r, buf, buflen, rej_uniform_table); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/consts.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/consts.h new file mode 100644 index 0000000000..c40947299c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/consts.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_AARCH64_CONSTS) +#define MLKEM_NATIVE_AARCH64_CONSTS + +#include +#include "common.h" + +#define zetas_mulcache_native MLKEM_NAMESPACE(zetas_mulcache_native) +extern const int16_t zetas_mulcache_native[256]; + +#define zetas_mulcache_twisted_native \ + MLKEM_NAMESPACE(zetas_mulcache_twisted_native) +extern const int16_t zetas_mulcache_twisted_native[256]; + +#endif /* MLKEM_NATIVE_AARCH64_CONSTS */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_clean.S new file mode 100644 index 0000000000..623a82ae9c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_clean.S @@ -0,0 +1,364 @@ +/// Copyright (c) 2024 The mlkem-native 
project authors +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. +.macro mulmodq dst, src, const, idx0, idx1 + // Signed barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. 
+ sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro gs_butterfly a, b, root, idx0, idx1 + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmodq \b, tmp, \root, \idx0, \idx1 +.endm + +.macro gs_butterfly_v a, b, root, root_twisted + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmod \b, tmp, \root, \root_twisted +.endm + +.macro mul_ninv dst0, dst1, dst2, dst3, src0, src1, src2, src3 + mulmod \dst0, \src0, ninv, ninv_tw + mulmod \dst1, \src1, ninv, ninv_tw + mulmod \dst2, \src2, ninv, ninv_tw + mulmod \dst3, \src3, ninv, ninv_tw +.endm + +.macro barrett_reduce a + sqdmulh t0.8h, \a\().8h, consts.h[1] + srshr t0.8h, t0.8h, #11 + mls \a\().8h, t0.8h, consts.h[0] +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 + ldr q_root1, [r01234_ptr, #-16] +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro transpose_single data_out, data_in + trn1 \data_out\()0.4s, \data_in\()0.4s, \data_in\()1.4s + trn2 \data_out\()1.4s, \data_in\()0.4s, \data_in\()1.4s + trn1 \data_out\()2.4s, 
\data_in\()2.4s, \data_in\()3.4s + trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +// For comparability reasons, the output range for the coefficients of this +// invNTT code is supposed to match the implementation from PQClean on commit +// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients +// are NOT canonically reduced. The ordering of the coefficients is canonical, +// also matching PQClean. + +.text + + .global MLKEM_ASM_NAMESPACE(intt_asm_clean) + + in .req x0 + r01234_ptr .req x1 + r56_ptr .req x2 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + consts .req v7 + q_consts .req q7 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + ninv .req v29 + q_ninv .req q29 + ninv_tw .req v30 + q_ninv_tw .req q30 + +/* Literal pool */ +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_consts: .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 +c_ninv: dup8h 512 +c_ninv_tw: 
dup8h 5040 + +MLKEM_ASM_NAMESPACE(intt_asm_clean): + push_stack + + ldr q_consts, c_consts + ldr q_ninv, c_ninv + ldr q_ninv_tw, c_ninv_tw + + mov inp, in + mov count, #8 + +scale_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + mul_ninv data0, data1, data2, data3, data0, data1, data2, data3 + // Bounds: Absolute value < q + + str q_data0, [inp], #64 + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, scale_start + + mov inp, in + mov count, #8 + + .p2align 2 +layer3456_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + transpose4 data // manual ld4 + + load_next_roots_56 + + // Layer 7 + gs_butterfly_v data0, data1, root1, root1_tw + gs_butterfly_v data2, data3, root2, root2_tw + // Bounds: + // data0, data2: < 2q + // data1, data3: < q + + // Layer 6 + gs_butterfly_v data0, data2, root0, root0_tw + gs_butterfly_v data1, data3, root0, root0_tw + // Bounds: + // data0: < 4q + // data1: < 2q + // data2, data3: < q + + transpose4 data + + load_next_roots_34 + + // Layer 5 + gs_butterfly data0, data1, root0, 2, 3 + gs_butterfly data2, data3, root0, 4, 5 + // Max bound: 8q + + // Not all of those reductions are needed, but the bounds tracking + // is easier if we uniformly reduce at this point. 
+ barrett_reduce data0 + barrett_reduce data2 + barrett_reduce data1 + barrett_reduce data3 + + // Bounds: q/2 + + // Layer 4 + gs_butterfly data0, data2, root0, 0, 1 + gs_butterfly data1, data3, root0, 0, 1 + // Bounds: < q + + str q_data0, [inp], #(64) + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, layer3456_start + + // --------------------------------------------------------------------- + + mov count, #4 + load_roots_012 + + .p2align 2 + +layer012_start: + + ldr q_data0, [in, #0] + ldr q_data1, [in, #(1*(512/8))] + ldr q_data2, [in, #(2*(512/8))] + ldr q_data3, [in, #(3*(512/8))] + ldr q_data4, [in, #(4*(512/8))] + ldr q_data5, [in, #(5*(512/8))] + ldr q_data6, [in, #(6*(512/8))] + ldr q_data7, [in, #(7*(512/8))] + + gs_butterfly data0, data1, root0, 6, 7 + gs_butterfly data2, data3, root1, 0, 1 + gs_butterfly data4, data5, root1, 2, 3 + gs_butterfly data6, data7, root1, 4, 5 + + gs_butterfly data0, data2, root0, 2, 3 + gs_butterfly data1, data3, root0, 2, 3 + gs_butterfly data4, data6, root0, 4, 5 + gs_butterfly data5, data7, root0, 4, 5 + + gs_butterfly data0, data4, root0, 0, 1 + gs_butterfly data1, data5, root0, 0, 1 + gs_butterfly data2, data6, root0, 0, 1 + gs_butterfly data3, data7, root0, 0, 1 + + // Bounds: < 8q + + str q_data4, [in, #(4*(512/8))] + str q_data5, [in, #(5*(512/8))] + str q_data6, [in, #(6*(512/8))] + str q_data7, [in, #(7*(512/8))] + + str q_data0, [in], #(16) + str q_data1, [in, #(-16 + 1*(512/8))] + str q_data2, [in, #(-16 + 2*(512/8))] + str q_data3, [in, #(-16 + 3*(512/8))] + + subs count, count, #1 + cbnz count, layer012_start + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_opt.S new file mode 100644 index 0000000000..e332efef8f --- /dev/null +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/intt_opt.S @@ -0,0 +1,1020 @@ +/// Copyright (c) 2024 The mlkem-native project authors +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. +.macro mulmodq dst, src, const, idx0, idx1 + // Signed barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. 
+ sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro gs_butterfly a, b, root, idx0, idx1 + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmodq \b, tmp, \root, \idx0, \idx1 +.endm + +.macro gs_butterfly_v a, b, root, root_twisted + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmod \b, tmp, \root, \root_twisted +.endm + +.macro mul_ninv dst0, dst1, dst2, dst3, src0, src1, src2, src3 + mulmod \dst0, \src0, ninv, ninv_tw + mulmod \dst1, \src1, ninv, ninv_tw + mulmod \dst2, \src2, ninv, ninv_tw + mulmod \dst3, \src3, ninv, ninv_tw +.endm + +.macro barrett_reduce a + sqdmulh t0.8h, \a\().8h, consts.h[1] + srshr t0.8h, t0.8h, #11 + mls \a\().8h, t0.8h, consts.h[0] +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 + ldr q_root1, [r01234_ptr, #-16] +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro transpose_single data_out, data_in + trn1 \data_out\()0.4s, \data_in\()0.4s, \data_in\()1.4s + trn2 \data_out\()1.4s, \data_in\()0.4s, \data_in\()1.4s + trn1 \data_out\()2.4s, 
\data_in\()2.4s, \data_in\()3.4s + trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +// For comparability reasons, the output range for the coefficients of this +// invNTT code is supposed to match the implementation from PQClean on commit +// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients +// are NOT canonically reduced. The ordering of the coefficients is canonical, +// also matching PQClean. + +.text + + .global MLKEM_ASM_NAMESPACE(intt_asm_opt) + + in .req x0 + r01234_ptr .req x1 + r56_ptr .req x2 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + consts .req v7 + q_consts .req q7 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + ninv .req v29 + q_ninv .req q29 + ninv_tw .req v30 + q_ninv_tw .req q30 + +/* Literal pool */ +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_consts: .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 +c_ninv: dup8h 512 +c_ninv_tw: 
dup8h 5040 + +MLKEM_ASM_NAMESPACE(intt_asm_opt): + push_stack + + ldr q_consts, c_consts + ldr q_ninv, c_ninv + ldr q_ninv_tw, c_ninv_tw + + mov inp, in + mov count, #8 + +scale_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + mul_ninv data0, data1, data2, data3, data0, data1, data2, data3 + // Bounds: Absolute value < q + + str q_data0, [inp], #64 + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, scale_start + + mov inp, in + mov count, #8 + + .p2align 2 + // Instructions: 11 + // Expected cycles: 20 + // Expected IPC: 0.55 + // + // Cycle bound: 20.0 + // IPC bound: 0.55 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q26, [x3, #0] // *............................. + ldr q8, [x3, #16] // ..*........................... + ldr q24, [x3, #32] // ....*......................... + ldr q16, [x3, #48] // ......*....................... + ldr q9, [x2], #(6*16) // ........*..................... + trn1 v0.4S, v24.4S, v16.4S // ..........*................... + ldr q6, [x2, #-80] // ...........*.................. + ldr q3, [x2, #-64] // .............*................ + ldr q15, [x2, #-48] // ...............*.............. + ldr q4, [x2, #-32] // .................*............ + ldr q28, [x2, #-16] // ...................*.......... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q26, [x3, #0] // *.............................. + // ldr q8, [x3, #16] // ..*............................ + // ldr q24, [x3, #32] // ....*.......................... + // ldr q16, [x3, #48] // ......*........................ + // trn1 v0.4S, v24.4S, v16.4S // ..........*.................... + // ldr q9, [x2], #(6*16) // ........*...................... 
+ // ldr q6, [x2, #-80] // ...........*................... + // ldr q3, [x2, #-64] // .............*................. + // ldr q15, [x2, #-48] // ...............*............... + // ldr q4, [x2, #-32] // .................*............. + // ldr q28, [x2, #-16] // ...................*........... + + sub count, count, #1 +layer3456_start: + // Instructions: 83 + // Expected cycles: 94 + // Expected IPC: 0.88 + // + // Cycle bound: 94.0 + // IPC bound: 0.88 + // + // Wall time: 3.34s + // User time: 3.34s + // + // ------------------------------------- cycle (expected) --------------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------------------ + trn1 v12.4S, v26.4S, v8.4S // *............................................................................................. + trn2 v26.4S, v26.4S, v8.4S // .*............................................................................................ + trn2 v8.4S, v24.4S, v16.4S // ..*........................................................................................... + trn2 v11.2D, v12.2D, v0.2D // ...*.......................................................................................... + trn1 v12.2D, v12.2D, v0.2D // ....*......................................................................................... + trn2 v16.2D, v26.2D, v8.2D // .....*........................................................................................ + trn1 v26.2D, v26.2D, v8.2D // ......*....................................................................................... + sub v8.8H, v11.8H, v16.8H // .......*...................................................................................... + add v11.8H, v11.8H, v16.8H // ........*..................................................................................... + sub v16.8H, v12.8H, v26.8H // .........*.................................................................................... 
+ add v12.8H, v12.8H, v26.8H // ..........*................................................................................... + sqrdmulh v26.8H, v8.8H, v28.8H // ...........*.................................................................................. + sqrdmulh v15.8H, v16.8H, v15.8H // ............*................................................................................. + mul v16.8H, v16.8H, v3.8H // .............*................................................................................ + mul v8.8H, v8.8H, v4.8H // ..............*............................................................................... + sub v0.8H, v12.8H, v11.8H // ...............*.............................................................................. + add v12.8H, v12.8H, v11.8H // ................*............................................................................. + mls v16.8H, v15.8H, v7.H[0] // .................*............................................................................ + mls v8.8H, v26.8H, v7.H[0] // ..................*........................................................................... + sqrdmulh v26.8H, v0.8H, v6.8H // ...................*.......................................................................... + mul v11.8H, v0.8H, v9.8H // ....................*......................................................................... + ldr q15, [x1], #16 // .....................*........................................................................ + sub v0.8H, v16.8H, v8.8H // .......................*...................................................................... + mls v11.8H, v26.8H, v7.H[0] // ........................*..................................................................... + add v26.8H, v16.8H, v8.8H // .........................*.................................................................... 
+ sqrdmulh v8.8H, v0.8H, v6.8H // ..........................*................................................................... + mul v16.8H, v0.8H, v9.8H // ...........................*.................................................................. + trn1 v0.4S, v12.4S, v26.4S // ............................*................................................................. + trn2 v12.4S, v12.4S, v26.4S // .............................*................................................................ + ldr q26, [x3, #64] // ..............................e............................................................... + mls v16.8H, v8.8H, v7.H[0] // ................................*............................................................. + ldr q8, [x3, #80] // .................................e............................................................ + ldr q24, [x3, #96] // ...................................e.......................................................... + trn1 v9.4S, v11.4S, v16.4S // .....................................*........................................................ + trn2 v11.4S, v11.4S, v16.4S // ......................................*....................................................... + ldr q16, [x3, #112] // .......................................e...................................................... + trn2 v6.2D, v0.2D, v9.2D // .........................................*.................................................... + trn2 v3.2D, v12.2D, v11.2D // ..........................................*................................................... + trn1 v0.2D, v0.2D, v9.2D // ...........................................*.................................................. + trn1 v12.2D, v12.2D, v11.2D // ............................................*................................................. + sub v11.8H, v6.8H, v3.8H // .............................................*................................................ 
+ sub v9.8H, v0.8H, v12.8H // ..............................................*............................................... + add v12.8H, v0.8H, v12.8H // ...............................................*.............................................. + sqrdmulh v0.8H, v11.8H, v15.H[5] // ................................................*............................................. + sqrdmulh v4.8H, v9.8H, v15.H[3] // .................................................*............................................ + mul v9.8H, v9.8H, v15.H[2] // ..................................................*........................................... + mul v11.8H, v11.8H, v15.H[4] // ...................................................*.......................................... + add v6.8H, v6.8H, v3.8H // ....................................................*......................................... + sqdmulh v3.8H, v12.8H, v7.H[1] // .....................................................*........................................ + mls v9.8H, v4.8H, v7.H[0] // ......................................................*....................................... + mls v11.8H, v0.8H, v7.H[0] // .......................................................*...................................... + sqdmulh v0.8H, v6.8H, v7.H[1] // ........................................................*..................................... + srshr v3.8H, v3.8H, #11 // .........................................................*.................................... + sqdmulh v4.8H, v9.8H, v7.H[1] // ..........................................................*................................... + sqdmulh v28.8H, v11.8H, v7.H[1] // ...........................................................*.................................. + mls v12.8H, v3.8H, v7.H[0] // ............................................................*................................. 
+ srshr v0.8H, v0.8H, #11 // .............................................................*................................ + srshr v3.8H, v4.8H, #11 // ..............................................................*............................... + srshr v4.8H, v28.8H, #11 // ...............................................................*.............................. + mls v6.8H, v0.8H, v7.H[0] // ................................................................*............................. + mls v9.8H, v3.8H, v7.H[0] // .................................................................*............................ + mls v11.8H, v4.8H, v7.H[0] // ..................................................................*........................... + trn1 v0.4S, v24.4S, v16.4S // ...................................................................e.......................... + sub v3.8H, v12.8H, v6.8H // ....................................................................*......................... + add v12.8H, v12.8H, v6.8H // .....................................................................*........................ + sub v6.8H, v9.8H, v11.8H // ......................................................................*....................... + sqrdmulh v4.8H, v3.8H, v15.H[1] // .......................................................................*...................... + mul v3.8H, v3.8H, v15.H[0] // ........................................................................*..................... + sqrdmulh v28.8H, v6.8H, v15.H[1] // .........................................................................*.................... + mul v15.8H, v6.8H, v15.H[0] // ..........................................................................*................... + add v11.8H, v9.8H, v11.8H // ...........................................................................*.................. 
+ mls v3.8H, v4.8H, v7.H[0] // ............................................................................*................. + str q12, [x3], #(64) // .............................................................................*................ + mls v15.8H, v28.8H, v7.H[0] // ..............................................................................*............... + str q11, [x3, #-48] // ...............................................................................*.............. + ldr q9, [x2], #(6*16) // ................................................................................e............. + str q3, [x3, #-32] // ..................................................................................*........... + ldr q6, [x2, #-80] // ...................................................................................e.......... + str q15, [x3, #-16] // .....................................................................................*........ + ldr q3, [x2, #-64] // ......................................................................................e....... + ldr q15, [x2, #-48] // ........................................................................................e..... + ldr q4, [x2, #-32] // ..........................................................................................e... + ldr q28, [x2, #-16] // ............................................................................................e. + + // ----------------------------------------------------------------- cycle (expected) ------------------------------------------------------------------> + // 0 25 50 75 100 125 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------ + // ldr q8, [x3, #(16*0)] // e...............................................................'.............................~....................................................... 
+ // ldr q9, [x3, #(16*1)] // ...e............................................................'................................~.................................................... + // ldr q10, [x3, #(16*2)] // .....e..........................................................'..................................~.................................................. + // ldr q11, [x3, #(16*3)] // .........e......................................................'......................................~.............................................. + // trn1 v25.4s, v8.4s, v9.4s // ................................................................*..................................................................................... + // trn2 v26.4s, v8.4s, v9.4s // ................................................................'*.................................................................................... + // trn1 v27.4s, v10.4s, v11.4s // .....................................e..........................'..................................................................~.................. + // trn2 v28.4s, v10.4s, v11.4s // ................................................................'.*................................................................................... + // trn2 v10.2d, v25.2d, v27.2d // ................................................................'..*.................................................................................. + // trn2 v11.2d, v26.2d, v28.2d // ................................................................'....*................................................................................ + // trn1 v8.2d, v25.2d, v27.2d // ................................................................'...*................................................................................. 
+ // trn1 v9.2d, v26.2d, v28.2d // ................................................................'.....*............................................................................... + // ldr q0, [x2], #(6*16) // ..................................................e.............'...............................................................................~..... + // ldr q4, [x2, #(-6*16 + 1*16)] // .....................................................e..........'..................................................................................~.. + // ldr q1, [x2, #(-6*16 + 2*16)] // ........................................................e.......'..................................................................................... + // ldr q5, [x2, #(-6*16 + 3*16)] // ..........................................................e.....'..................................................................................... + // ldr q2, [x2, #(-6*16 + 4*16)] // ............................................................e...'..................................................................................... + // ldr q6, [x2, #(-6*16 + 5*16)] // ..............................................................e.'..................................................................................... + // sub v24.8h, v8.8h, v9.8h // ................................................................'........*............................................................................ + // add v8.8h, v8.8h, v9.8h // ................................................................'.........*........................................................................... + // sqrdmulh v27.8h, v24.8h, v5.8h // ................................................................'...........*......................................................................... 
+ // mul v9.8h, v24.8h, v1.8h // ................................................................'............*........................................................................ + // mls v9.8h, v27.8h, v7.h[0] // ................................................................'................*.................................................................... + // sub v24.8h, v10.8h, v11.8h // ................................................................'......*.............................................................................. + // add v10.8h, v10.8h, v11.8h // ................................................................'.......*............................................................................. + // sqrdmulh v27.8h, v24.8h, v6.8h // ................................................................'..........*.......................................................................... + // mul v11.8h, v24.8h, v2.8h // ................................................................'.............*....................................................................... + // mls v11.8h, v27.8h, v7.h[0] // ................................................................'.................*................................................................... + // sub v24.8h, v8.8h, v10.8h // ................................................................'..............*...................................................................... + // add v8.8h, v8.8h, v10.8h // ................................................................'...............*..................................................................... + // sqrdmulh v27.8h, v24.8h, v4.8h // ................................................................'..................*.................................................................. 
+ // mul v10.8h, v24.8h, v0.8h // ................................................................'...................*................................................................. + // mls v10.8h, v27.8h, v7.h[0] // ................................................................'.......................*............................................................. + // sub v24.8h, v9.8h, v11.8h // ................................................................'......................*.............................................................. + // add v9.8h, v9.8h, v11.8h // ................................................................'........................*............................................................ + // sqrdmulh v27.8h, v24.8h, v4.8h // ................................................................'.........................*........................................................... + // mul v11.8h, v24.8h, v0.8h // ................................................................'..........................*.......................................................... + // mls v11.8h, v27.8h, v7.h[0] // ..~.............................................................'...............................*..................................................... + // trn1 v25.4s, v8.4s, v9.4s // ................................................................'...........................*......................................................... + // trn2 v26.4s, v8.4s, v9.4s // ................................................................'............................*........................................................ + // trn1 v27.4s, v10.4s, v11.4s // .......~........................................................'....................................*................................................ 
+ // trn2 v28.4s, v10.4s, v11.4s // ........~.......................................................'.....................................*............................................... + // trn2 v10.2d, v25.2d, v27.2d // ...........~....................................................'........................................*............................................ + // trn2 v11.2d, v26.2d, v28.2d // ............~...................................................'.........................................*........................................... + // trn1 v8.2d, v25.2d, v27.2d // .............~..................................................'..........................................*.......................................... + // trn1 v9.2d, v26.2d, v28.2d // ..............~.................................................'...........................................*......................................... + // ldr q0, [x1], #16 // ................................................................'....................*................................................................ + // sub v24.8h, v8.8h, v9.8h // ................~...............................................'.............................................*....................................... + // add v8.8h, v8.8h, v9.8h // .................~..............................................'..............................................*...................................... + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...................~............................................'................................................*.................................... + // mul v9.8h, v24.8h, v0.h[2] // ....................~...........................................'.................................................*................................... 
+ // mls v9.8h, v27.8h, v7.h[0] // ........................~.......................................'.....................................................*............................... + // sub v24.8h, v10.8h, v11.8h // ...............~................................................'............................................*........................................ + // add v10.8h, v10.8h, v11.8h // ......................~.........................................'...................................................*................................. + // sqrdmulh v27.8h, v24.8h, v0.h[5] // ..................~.............................................'...............................................*..................................... + // mul v11.8h, v24.8h, v0.h[4] // .....................~..........................................'..................................................*.................................. + // mls v11.8h, v27.8h, v7.h[0] // .........................~......................................'......................................................*.............................. + // sqdmulh v25.8h, v8.8h, v7.h[1] // .......................~........................................'....................................................*................................ + // srshr v25.8h, v25.8h, #11 // ...........................~....................................'........................................................*............................ + // mls v8.8h, v25.8h, v7.h[0] // ..............................~.................................'...........................................................*......................... + // sqdmulh v25.8h, v10.8h, v7.h[1] // ..........................~.....................................'.......................................................*............................. 
+ // srshr v25.8h, v25.8h, #11 // ...............................~................................'............................................................*........................ + // mls v10.8h, v25.8h, v7.h[0] // ..................................~.............................'...............................................................*..................... + // sqdmulh v25.8h, v9.8h, v7.h[1] // ............................~...................................'.........................................................*........................... + // srshr v25.8h, v25.8h, #11 // ................................~...............................'.............................................................*....................... + // mls v9.8h, v25.8h, v7.h[0] // ...................................~............................'................................................................*.................... + // sqdmulh v25.8h, v11.8h, v7.h[1] // .............................~..................................'..........................................................*.......................... + // srshr v25.8h, v25.8h, #11 // .................................~..............................'..............................................................*...................... + // mls v11.8h, v25.8h, v7.h[0] // ....................................~...........................'.................................................................*................... + // sub v24.8h, v8.8h, v10.8h // ......................................~.........................'...................................................................*................. + // add v8.8h, v8.8h, v10.8h // .......................................~........................'....................................................................*................ 
+ // sqrdmulh v27.8h, v24.8h, v0.h[1] // .........................................~......................'......................................................................*.............. + // mul v10.8h, v24.8h, v0.h[0] // ..........................................~.....................'.......................................................................*............. + // mls v10.8h, v27.8h, v7.h[0] // ..............................................~.................'...........................................................................*......... + // sub v24.8h, v9.8h, v11.8h // ........................................~.......................'.....................................................................*............... + // add v9.8h, v9.8h, v11.8h // .............................................~..................'..........................................................................*.......... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................................~....................'........................................................................*............ + // mul v11.8h, v24.8h, v0.h[0] // ............................................~...................'.........................................................................*........... + // mls v11.8h, v27.8h, v7.h[0] // ................................................~...............'.............................................................................*....... + // str q8, [x3], #(64) // ...............................................~................'............................................................................*........ + // str q9, [x3, #(-64 + 16*1)] // .................................................~..............'..............................................................................*...... 
+ // str q10, [x3, #(-64 + 16*2)] // ....................................................~...........'.................................................................................*... + // str q11, [x3, #(-64 + 16*3)] // .......................................................~........'....................................................................................* + + sub count, count, #1 + cbnz count, layer3456_start + // Instructions: 72 + // Expected cycles: 79 + // Expected IPC: 0.91 + // + // Cycle bound: 79.0 + // IPC bound: 0.91 + // + // Wall time: 9.28s + // User time: 9.28s + // + // ------------------------------ cycle (expected) ------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|--- + trn1 v11.4S, v26.4S, v8.4S // *.............................................................................. + trn2 v24.4S, v24.4S, v16.4S // .*............................................................................. + trn2 v26.4S, v26.4S, v8.4S // ..*............................................................................ + trn1 v18.2D, v11.2D, v0.2D // ...*........................................................................... + trn2 v11.2D, v11.2D, v0.2D // ....*.......................................................................... + trn2 v12.2D, v26.2D, v24.2D // .....*......................................................................... + trn1 v8.2D, v26.2D, v24.2D // ......*........................................................................ + sub v26.8H, v11.8H, v12.8H // .......*....................................................................... + sub v13.8H, v18.8H, v8.8H // ........*...................................................................... + add v24.8H, v18.8H, v8.8H // .........*..................................................................... 
+ mul v16.8H, v26.8H, v4.8H // ..........*.................................................................... + sqrdmulh v17.8H, v13.8H, v15.8H // ...........*................................................................... + mul v3.8H, v13.8H, v3.8H // ............*.................................................................. + sqrdmulh v26.8H, v26.8H, v28.8H // .............*................................................................. + add v10.8H, v11.8H, v12.8H // ..............*................................................................ + mls v3.8H, v17.8H, v7.H[0] // ................*.............................................................. + mls v16.8H, v26.8H, v7.H[0] // .................*............................................................. + sub v26.8H, v24.8H, v10.8H // ..................*............................................................ + ldr q4, [x1], #16 // ...................*........................................................... + sub v12.8H, v3.8H, v16.8H // .....................*......................................................... + sqrdmulh v15.8H, v26.8H, v6.8H // ......................*........................................................ + mul v11.8H, v26.8H, v9.8H // .......................*....................................................... + mul v8.8H, v12.8H, v9.8H // ........................*...................................................... + sqrdmulh v12.8H, v12.8H, v6.8H // .........................*..................................................... + add v0.8H, v24.8H, v10.8H // ..........................*.................................................... + mls v11.8H, v15.8H, v7.H[0] // ...........................*................................................... + add v6.8H, v3.8H, v16.8H // ............................*.................................................. 
+ mls v8.8H, v12.8H, v7.H[0] // .............................*................................................. + trn2 v26.4S, v0.4S, v6.4S // ...............................*............................................... + trn2 v12.4S, v11.4S, v8.4S // .................................*............................................. + trn1 v3.4S, v11.4S, v8.4S // ..................................*............................................ + trn1 v17.4S, v0.4S, v6.4S // ...................................*........................................... + trn1 v8.2D, v26.2D, v12.2D // ....................................*.......................................... + trn2 v13.2D, v26.2D, v12.2D // .....................................*......................................... + trn1 v11.2D, v17.2D, v3.2D // ......................................*........................................ + trn2 v15.2D, v17.2D, v3.2D // .......................................*....................................... + sub v12.8H, v11.8H, v8.8H // ........................................*...................................... + add v16.8H, v15.8H, v13.8H // .........................................*..................................... + sub v26.8H, v15.8H, v13.8H // ..........................................*.................................... + mul v0.8H, v12.8H, v4.H[2] // ...........................................*................................... + sqrdmulh v9.8H, v12.8H, v4.H[3] // ............................................*.................................. + mul v13.8H, v26.8H, v4.H[4] // .............................................*................................. + sqrdmulh v26.8H, v26.8H, v4.H[5] // ..............................................*................................ + add v24.8H, v11.8H, v8.8H // ...............................................*............................... 
+ mls v0.8H, v9.8H, v7.H[0] // ................................................*.............................. + sqdmulh v12.8H, v16.8H, v7.H[1] // .................................................*............................. + mls v13.8H, v26.8H, v7.H[0] // ..................................................*............................ + sqdmulh v11.8H, v24.8H, v7.H[1] // ...................................................*........................... + sqdmulh v8.8H, v0.8H, v7.H[1] // ....................................................*.......................... + srshr v12.8H, v12.8H, #11 // .....................................................*......................... + sqdmulh v26.8H, v13.8H, v7.H[1] // ......................................................*........................ + srshr v11.8H, v11.8H, #11 // .......................................................*....................... + mls v16.8H, v12.8H, v7.H[0] // ........................................................*...................... + srshr v8.8H, v8.8H, #11 // .........................................................*..................... + srshr v26.8H, v26.8H, #11 // ..........................................................*.................... + mls v24.8H, v11.8H, v7.H[0] // ...........................................................*................... + mls v0.8H, v8.8H, v7.H[0] // ............................................................*.................. + mls v13.8H, v26.8H, v7.H[0] // .............................................................*................. + sub v26.8H, v24.8H, v16.8H // ...............................................................*............... + add v15.8H, v24.8H, v16.8H // ................................................................*.............. + sub v12.8H, v0.8H, v13.8H // .................................................................*............. 
+ mul v11.8H, v26.8H, v4.H[0] // ..................................................................*............ + sqrdmulh v16.8H, v26.8H, v4.H[1] // ...................................................................*........... + mul v26.8H, v12.8H, v4.H[0] // ....................................................................*.......... + sqrdmulh v8.8H, v12.8H, v4.H[1] // .....................................................................*......... + add v12.8H, v0.8H, v13.8H // ......................................................................*........ + mls v11.8H, v16.8H, v7.H[0] // .......................................................................*....... + str q15, [x3], #(64) // ........................................................................*...... + mls v26.8H, v8.8H, v7.H[0] // .........................................................................*..... + str q12, [x3, #-48] // ..........................................................................*.... + str q11, [x3, #-32] // ............................................................................*.. + str q26, [x3, #-16] // ..............................................................................* + + // ------------------------------ cycle (expected) ------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|--- + // trn1 v12.4S, v26.4S, v8.4S // *.............................................................................. + // trn2 v26.4S, v26.4S, v8.4S // ..*............................................................................ + // trn2 v8.4S, v24.4S, v16.4S // .*............................................................................. + // trn2 v11.2D, v12.2D, v0.2D // ....*.......................................................................... + // trn1 v12.2D, v12.2D, v0.2D // ...*........................................................................... 
+ // trn2 v16.2D, v26.2D, v8.2D // .....*......................................................................... + // trn1 v26.2D, v26.2D, v8.2D // ......*........................................................................ + // sub v8.8H, v11.8H, v16.8H // .......*....................................................................... + // add v11.8H, v11.8H, v16.8H // ..............*................................................................ + // sub v16.8H, v12.8H, v26.8H // ........*...................................................................... + // add v12.8H, v12.8H, v26.8H // .........*..................................................................... + // sqrdmulh v26.8H, v8.8H, v28.8H // .............*................................................................. + // sqrdmulh v15.8H, v16.8H, v15.8H // ...........*................................................................... + // mul v16.8H, v16.8H, v3.8H // ............*.................................................................. + // mul v8.8H, v8.8H, v4.8H // ..........*.................................................................... + // sub v0.8H, v12.8H, v11.8H // ..................*............................................................ + // add v12.8H, v12.8H, v11.8H // ..........................*.................................................... + // mls v16.8H, v15.8H, v7.H[0] // ................*.............................................................. + // mls v8.8H, v26.8H, v7.H[0] // .................*............................................................. + // sqrdmulh v26.8H, v0.8H, v6.8H // ......................*........................................................ + // mul v11.8H, v0.8H, v9.8H // .......................*....................................................... + // ldr q15, [x1], #16 // ...................*........................................................... 
+ // sub v0.8H, v16.8H, v8.8H // .....................*......................................................... + // mls v11.8H, v26.8H, v7.H[0] // ...........................*................................................... + // add v26.8H, v16.8H, v8.8H // ............................*.................................................. + // sqrdmulh v8.8H, v0.8H, v6.8H // .........................*..................................................... + // mul v16.8H, v0.8H, v9.8H // ........................*...................................................... + // trn1 v0.4S, v12.4S, v26.4S // ...................................*........................................... + // trn2 v12.4S, v12.4S, v26.4S // ...............................*............................................... + // mls v16.8H, v8.8H, v7.H[0] // .............................*................................................. + // trn1 v9.4S, v11.4S, v16.4S // ..................................*............................................ + // trn2 v11.4S, v11.4S, v16.4S // .................................*............................................. + // trn2 v6.2D, v0.2D, v9.2D // .......................................*....................................... + // trn2 v3.2D, v12.2D, v11.2D // .....................................*......................................... + // trn1 v0.2D, v0.2D, v9.2D // ......................................*........................................ + // trn1 v12.2D, v12.2D, v11.2D // ....................................*.......................................... + // sub v11.8H, v6.8H, v3.8H // ..........................................*.................................... + // sub v9.8H, v0.8H, v12.8H // ........................................*...................................... + // add v12.8H, v0.8H, v12.8H // ...............................................*............................... 
+ // sqrdmulh v0.8H, v11.8H, v15.H[5] // ..............................................*................................ + // sqrdmulh v4.8H, v9.8H, v15.H[3] // ............................................*.................................. + // mul v9.8H, v9.8H, v15.H[2] // ...........................................*................................... + // mul v11.8H, v11.8H, v15.H[4] // .............................................*................................. + // add v6.8H, v6.8H, v3.8H // .........................................*..................................... + // sqdmulh v3.8H, v12.8H, v7.H[1] // ...................................................*........................... + // mls v9.8H, v4.8H, v7.H[0] // ................................................*.............................. + // mls v11.8H, v0.8H, v7.H[0] // ..................................................*............................ + // sqdmulh v0.8H, v6.8H, v7.H[1] // .................................................*............................. + // srshr v3.8H, v3.8H, #11 // .......................................................*....................... + // sqdmulh v4.8H, v9.8H, v7.H[1] // ....................................................*.......................... + // sqdmulh v28.8H, v11.8H, v7.H[1] // ......................................................*........................ + // mls v12.8H, v3.8H, v7.H[0] // ...........................................................*................... + // srshr v0.8H, v0.8H, #11 // .....................................................*......................... + // srshr v3.8H, v4.8H, #11 // .........................................................*..................... + // srshr v4.8H, v28.8H, #11 // ..........................................................*.................... + // mls v6.8H, v0.8H, v7.H[0] // ........................................................*...................... 
+ // mls v9.8H, v3.8H, v7.H[0] // ............................................................*.................. + // mls v11.8H, v4.8H, v7.H[0] // .............................................................*................. + // sub v3.8H, v12.8H, v6.8H // ...............................................................*............... + // add v12.8H, v12.8H, v6.8H // ................................................................*.............. + // sub v6.8H, v9.8H, v11.8H // .................................................................*............. + // sqrdmulh v4.8H, v3.8H, v15.H[1] // ...................................................................*........... + // mul v3.8H, v3.8H, v15.H[0] // ..................................................................*............ + // sqrdmulh v28.8H, v6.8H, v15.H[1] // .....................................................................*......... + // mul v15.8H, v6.8H, v15.H[0] // ....................................................................*.......... + // add v11.8H, v9.8H, v11.8H // ......................................................................*........ + // mls v3.8H, v4.8H, v7.H[0] // .......................................................................*....... + // str q12, [x3], #(64) // ........................................................................*...... + // mls v15.8H, v28.8H, v7.H[0] // .........................................................................*..... + // str q11, [x3, #-48] // ..........................................................................*.... + // str q3, [x3, #-32] // ............................................................................*.. 
+ // str q15, [x3, #-16] // ..............................................................................* + + + // --------------------------------------------------------------------- + + mov count, #4 + load_roots_012 + + .p2align 2 + + // Instructions: 12 + // Expected cycles: 19 + // Expected IPC: 0.63 + // + // Cycle bound: 19.0 + // IPC bound: 0.63 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q24, [x0, #128] // *............................. + ldr q16, [x0, #192] // ..*........................... + ldr q9, [x0, #256] // ....*......................... + ldr q6, [x0, #320] // ......*....................... + ldr q3, [x0, #384] // ........*..................... + ldr q4, [x0, #448] // ..........*................... + add v28.8H, v9.8H, v6.8H // ............*................. + add v19.8H, v24.8H, v16.8H // .............*................ + add v13.8H, v3.8H, v4.8H // ..............*............... + ldr q11, [x0, #0] // ...............*.............. + add v23.8H, v28.8H, v13.8H // .................*............ + ldr q15, [x0, #64] // ..................*........... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q11, [x0, #0] // ...............*............... + // ldr q15, [x0, #64] // ..................*............ + // ldr q24, [x0, #128] // *.............................. + // ldr q16, [x0, #192] // ..*............................ + // ldr q9, [x0, #256] // ....*.......................... + // ldr q6, [x0, #320] // ......*........................ + // ldr q3, [x0, #384] // ........*...................... + // ldr q4, [x0, #448] // ..........*.................... + // add v28.8H, v9.8H, v6.8H // ............*.................. + // add v13.8H, v3.8H, v4.8H // ..............*................ + // add v19.8H, v24.8H, v16.8H // .............*................. 
+ // add v23.8H, v28.8H, v13.8H // .................*............. + + sub count, count, #1 +layer012_start: + // Instructions: 76 + // Expected cycles: 84 + // Expected IPC: 0.90 + // + // Cycle bound: 84.0 + // IPC bound: 0.90 + // + // Wall time: 2.81s + // User time: 2.81s + // + // -------------------------------- cycle (expected) ---------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|-------- + sub v12.8H, v11.8H, v15.8H // *................................................................................... + add v26.8H, v11.8H, v15.8H // .*.................................................................................. + sub v8.8H, v24.8H, v16.8H // ..*................................................................................. + sqrdmulh v11.8H, v12.8H, v0.H[7] // ...*................................................................................ + mul v12.8H, v12.8H, v0.H[6] // ....*............................................................................... + sub v16.8H, v26.8H, v19.8H // .....*.............................................................................. + add v26.8H, v26.8H, v19.8H // ......*............................................................................. + sqrdmulh v15.8H, v8.8H, v1.H[1] // .......*............................................................................ + mul v8.8H, v8.8H, v1.H[0] // ........*........................................................................... + mls v12.8H, v11.8H, v7.H[0] // .........*.......................................................................... + sub v11.8H, v9.8H, v6.8H // ..........*......................................................................... + sqrdmulh v24.8H, v16.8H, v0.H[3] // ...........*........................................................................ 
+ mul v16.8H, v16.8H, v0.H[2] // ............*....................................................................... + sub v9.8H, v26.8H, v23.8H // .............*...................................................................... + add v26.8H, v26.8H, v23.8H // ..............*..................................................................... + mls v8.8H, v15.8H, v7.H[0] // ...............*.................................................................... + sqrdmulh v15.8H, v11.8H, v1.H[3] // ................*................................................................... + mul v11.8H, v11.8H, v1.H[2] // .................*.................................................................. + sub v6.8H, v3.8H, v4.8H // ..................*................................................................. + sub v3.8H, v12.8H, v8.8H // ...................*................................................................ + add v12.8H, v12.8H, v8.8H // ....................*............................................................... + mls v11.8H, v15.8H, v7.H[0] // .....................*.............................................................. + sqrdmulh v8.8H, v6.8H, v1.H[5] // ......................*............................................................. + mls v16.8H, v24.8H, v7.H[0] // .......................*............................................................ + mul v15.8H, v6.8H, v1.H[4] // ........................*........................................................... + sqrdmulh v24.8H, v3.8H, v0.H[3] // .........................*.......................................................... + mul v6.8H, v3.8H, v0.H[2] // ..........................*......................................................... + sqrdmulh v3.8H, v9.8H, v0.H[1] // ...........................*........................................................ 
+ mul v9.8H, v9.8H, v0.H[0] // ............................*....................................................... + str q26, [x0], #(16) // .............................*...................................................... + mls v15.8H, v8.8H, v7.H[0] // ..............................*..................................................... + mls v6.8H, v24.8H, v7.H[0] // ...............................*.................................................... + sub v26.8H, v28.8H, v13.8H // ................................*................................................... + mls v9.8H, v3.8H, v7.H[0] // .................................*.................................................. + sub v8.8H, v11.8H, v15.8H // ..................................*................................................. + sqrdmulh v24.8H, v26.8H, v0.H[5] // ...................................*................................................ + mul v26.8H, v26.8H, v0.H[4] // ....................................*............................................... + add v11.8H, v11.8H, v15.8H // .....................................*.............................................. + sqrdmulh v15.8H, v8.8H, v0.H[5] // ......................................*............................................. + mul v8.8H, v8.8H, v0.H[4] // .......................................*............................................ + mls v26.8H, v24.8H, v7.H[0] // ........................................*........................................... + sub v24.8H, v12.8H, v11.8H // .........................................*.......................................... + add v12.8H, v12.8H, v11.8H // ..........................................*......................................... + mls v8.8H, v15.8H, v7.H[0] // ...........................................*........................................ + sqrdmulh v11.8H, v24.8H, v0.H[1] // ............................................*....................................... 
+ mul v15.8H, v24.8H, v0.H[0] // .............................................*...................................... + sub v24.8H, v16.8H, v26.8H // ..............................................*..................................... + add v26.8H, v16.8H, v26.8H // ...............................................*.................................... + sub v16.8H, v6.8H, v8.8H // ................................................*................................... + mls v15.8H, v11.8H, v7.H[0] // .................................................*.................................. + sqrdmulh v11.8H, v24.8H, v0.H[1] // ..................................................*................................. + mul v24.8H, v24.8H, v0.H[0] // ...................................................*................................ + add v8.8H, v6.8H, v8.8H // ....................................................*............................... + sqrdmulh v6.8H, v16.8H, v0.H[1] // .....................................................*.............................. + mul v16.8H, v16.8H, v0.H[0] // ......................................................*............................. + mls v24.8H, v11.8H, v7.H[0] // .......................................................*............................ + str q9, [x0, #240] // ........................................................*........................... + ldr q11, [x0, #0] // .........................................................e.......................... + mls v16.8H, v6.8H, v7.H[0] // ...........................................................*........................ + str q15, [x0, #304] // ............................................................*....................... + ldr q15, [x0, #64] // .............................................................e...................... + str q24, [x0, #368] // ...............................................................*.................... 
+ ldr q24, [x0, #128] // ................................................................e................... + str q16, [x0, #432] // ..................................................................*................. + ldr q16, [x0, #192] // ...................................................................e................ + str q12, [x0, #48] // .....................................................................*.............. + ldr q9, [x0, #256] // ......................................................................e............. + ldr q6, [x0, #320] // ........................................................................e........... + ldr q3, [x0, #384] // ..........................................................................e......... + ldr q4, [x0, #448] // ............................................................................e....... + str q26, [x0, #112] // ..............................................................................*..... + add v28.8H, v9.8H, v6.8H // ...............................................................................e.... + add v13.8H, v3.8H, v4.8H // ................................................................................e... + str q8, [x0, #176] // .................................................................................*.. + add v19.8H, v24.8H, v16.8H // ..................................................................................e. + add v23.8H, v28.8H, v13.8H // ...................................................................................e + + // --------------------------------------------- cycle (expected) ---------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|-------- + // ldr q8, [x0, #0] // e..........................'........................................................~........................ 
+ // ldr q9, [x0, #(1*(512/8))] // ....e......................'............................................................~.................... + // ldr q10, [x0, #(2*(512/8))] // .......e...................'...............................................................~................. + // ldr q11, [x0, #(3*(512/8))] // ..........e................'..................................................................~.............. + // ldr q12, [x0, #(4*(512/8))] // .............e.............'.....................................................................~........... + // ldr q13, [x0, #(5*(512/8))] // ...............e...........'.......................................................................~......... + // ldr q14, [x0, #(6*(512/8))] // .................e.........'.........................................................................~....... + // ldr q15, [x0, #(7*(512/8))] // ...................e.......'...........................................................................~..... + // sub v24.8h, v8.8h, v9.8h // ...........................*................................................................................. + // add v8.8h, v8.8h, v9.8h // ...........................'*................................................................................ + // sqrdmulh v27.8h, v24.8h, v0.h[7] // ...........................'..*.............................................................................. + // mul v9.8h, v24.8h, v0.h[6] // ...........................'...*............................................................................. + // mls v9.8h, v27.8h, v7.h[0] // ...........................'........*........................................................................ + // sub v24.8h, v10.8h, v11.8h // ...........................'.*............................................................................... 
+ // add v10.8h, v10.8h, v11.8h // .........................e.'................................................................................. + // sqrdmulh v27.8h, v24.8h, v1.h[1] // ...........................'......*.......................................................................... + // mul v11.8h, v24.8h, v1.h[0] // ...........................'.......*......................................................................... + // mls v11.8h, v27.8h, v7.h[0] // ...........................'..............*.................................................................. + // sub v24.8h, v12.8h, v13.8h // ...........................'.........*....................................................................... + // add v12.8h, v12.8h, v13.8h // ......................e....'..............................................................................~.. + // sqrdmulh v27.8h, v24.8h, v1.h[3] // ...........................'...............*................................................................. + // mul v13.8h, v24.8h, v1.h[2] // ...........................'................*................................................................ + // mls v13.8h, v27.8h, v7.h[0] // ...........................'....................*............................................................ + // sub v24.8h, v14.8h, v15.8h // ...........................'.................*............................................................... + // add v14.8h, v14.8h, v15.8h // .......................e...'...............................................................................~. + // sqrdmulh v27.8h, v24.8h, v1.h[5] // ...........................'.....................*........................................................... + // mul v15.8h, v24.8h, v1.h[4] // ...........................'.......................*......................................................... 
+ // mls v15.8h, v27.8h, v7.h[0] // ...........................'.............................*................................................... + // sub v24.8h, v8.8h, v10.8h // ...........................'....*............................................................................ + // add v8.8h, v8.8h, v10.8h // ...........................'.....*........................................................................... + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...........................'..........*...................................................................... + // mul v10.8h, v24.8h, v0.h[2] // ...........................'...........*..................................................................... + // mls v10.8h, v27.8h, v7.h[0] // ...........................'......................*.......................................................... + // sub v24.8h, v9.8h, v11.8h // ...........................'..................*.............................................................. + // add v9.8h, v9.8h, v11.8h // ...........................'...................*............................................................. + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...........................'........................*........................................................ + // mul v11.8h, v24.8h, v0.h[2] // ...........................'.........................*....................................................... + // mls v11.8h, v27.8h, v7.h[0] // ...........................'..............................*.................................................. + // sub v24.8h, v12.8h, v14.8h // ...........................'...............................*................................................. + // add v12.8h, v12.8h, v14.8h // ..........................e'................................................................................. 
+ // sqrdmulh v27.8h, v24.8h, v0.h[5] // ...........................'..................................*.............................................. + // mul v14.8h, v24.8h, v0.h[4] // ...........................'...................................*............................................. + // mls v14.8h, v27.8h, v7.h[0] // ...........................'.......................................*......................................... + // sub v24.8h, v13.8h, v15.8h // ...........................'.................................*............................................... + // add v13.8h, v13.8h, v15.8h // ...........................'....................................*............................................ + // sqrdmulh v27.8h, v24.8h, v0.h[5] // ...........................'.....................................*........................................... + // mul v15.8h, v24.8h, v0.h[4] // ...........................'......................................*.......................................... + // mls v15.8h, v27.8h, v7.h[0] // ...........................'..........................................*...................................... + // sub v24.8h, v8.8h, v12.8h // ...........................'............*.................................................................... + // add v8.8h, v8.8h, v12.8h // ...........................'.............*................................................................... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'..........................*...................................................... + // mul v12.8h, v24.8h, v0.h[0] // ...........................'...........................*..................................................... + // mls v12.8h, v27.8h, v7.h[0] // ...........................'................................*................................................ 
+ // sub v24.8h, v9.8h, v13.8h // ...........................'........................................*........................................ + // add v9.8h, v9.8h, v13.8h // ...........................'.........................................*....................................... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'...........................................*..................................... + // mul v13.8h, v24.8h, v0.h[0] // ...........................'............................................*.................................... + // mls v13.8h, v27.8h, v7.h[0] // ...........................'................................................*................................ + // sub v24.8h, v10.8h, v14.8h // ...........................'.............................................*................................... + // add v10.8h, v10.8h, v14.8h // ...........................'..............................................*.................................. + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'.................................................*............................... + // mul v14.8h, v24.8h, v0.h[0] // ...........................'..................................................*.............................. + // mls v14.8h, v27.8h, v7.h[0] // ...........................'......................................................*.......................... + // sub v24.8h, v11.8h, v15.8h // ...........................'...............................................*................................. + // add v11.8h, v11.8h, v15.8h // ...........................'...................................................*............................. + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'....................................................*............................ 
+ // mul v15.8h, v24.8h, v0.h[0] // ...........................'.....................................................*........................... + // mls v15.8h, v27.8h, v7.h[0] // ..~........................'..........................................................*...................... + // str q12, [x0, #(4*(512/8))] // ...........................'.......................................................*......................... + // str q13, [x0, #(5*(512/8))] // ...~.......................'...........................................................*..................... + // str q14, [x0, #(6*(512/8))] // ......~....................'..............................................................*.................. + // str q15, [x0, #(7*(512/8))] // .........~.................'.................................................................*............... + // str q8, [x0], #(16) // ...........................'............................*.................................................... + // str q9, [x0, #(-16 + 1*(512/8))] // ............~..............'....................................................................*............ + // str q10, [x0, #(-16 + 2*(512/8))] // .....................~.....'.............................................................................*... + // str q11, [x0, #(-16 + 3*(512/8))] // ........................~..'................................................................................* + + sub count, count, #1 + cbnz count, layer012_start + // Instructions: 64 + // Expected cycles: 66 + // Expected IPC: 0.97 + // + // Cycle bound: 66.0 + // IPC bound: 0.97 + // + // Wall time: 8.33s + // User time: 8.33s + // + // ----------------------- cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|--------------- + add v10.8H, v11.8H, v15.8H // *................................................................. 
+ sub v12.8H, v28.8H, v13.8H // .*................................................................ + sub v11.8H, v11.8H, v15.8H // ..*............................................................... + sub v22.8H, v10.8H, v19.8H // ...*.............................................................. + mul v18.8H, v12.8H, v0.H[4] // ....*............................................................. + sqrdmulh v26.8H, v12.8H, v0.H[5] // .....*............................................................ + sqrdmulh v12.8H, v22.8H, v0.H[3] // ......*........................................................... + mul v13.8H, v22.8H, v0.H[2] // .......*.......................................................... + sub v31.8H, v24.8H, v16.8H // ........*......................................................... + sqrdmulh v22.8H, v11.8H, v0.H[7] // .........*........................................................ + mls v18.8H, v26.8H, v7.H[0] // ..........*....................................................... + mls v13.8H, v12.8H, v7.H[0] // ...........*...................................................... + sqrdmulh v2.8H, v31.8H, v1.H[1] // ............*..................................................... + mul v5.8H, v31.8H, v1.H[0] // .............*.................................................... + mul v15.8H, v11.8H, v0.H[6] // ..............*................................................... + sub v12.8H, v13.8H, v18.8H // ...............*.................................................. + sub v4.8H, v3.8H, v4.8H // ................*................................................. + mls v5.8H, v2.8H, v7.H[0] // .................*................................................ + sqrdmulh v26.8H, v12.8H, v0.H[1] // ..................*............................................... + mul v12.8H, v12.8H, v0.H[0] // ...................*.............................................. 
+ mls v15.8H, v22.8H, v7.H[0] // ....................*............................................. + sqrdmulh v8.8H, v4.8H, v1.H[5] // .....................*............................................ + mul v4.8H, v4.8H, v1.H[4] // ......................*........................................... + mls v12.8H, v26.8H, v7.H[0] // .......................*.......................................... + sub v21.8H, v15.8H, v5.8H // ........................*......................................... + sub v28.8H, v9.8H, v6.8H // .........................*........................................ + mls v4.8H, v8.8H, v7.H[0] // ..........................*....................................... + mul v24.8H, v21.8H, v0.H[2] // ...........................*...................................... + sqrdmulh v8.8H, v21.8H, v0.H[3] // ............................*..................................... + sqrdmulh v6.8H, v28.8H, v1.H[3] // .............................*.................................... + add v19.8H, v10.8H, v19.8H // ..............................*................................... + mul v28.8H, v28.8H, v1.H[2] // ...............................*.................................. + mls v24.8H, v8.8H, v7.H[0] // ................................*................................. + sub v11.8H, v19.8H, v23.8H // .................................*................................ + str q12, [x0, #384] // ..................................*............................... + mls v28.8H, v6.8H, v7.H[0] // ...................................*.............................. + sqrdmulh v16.8H, v11.8H, v0.H[1] // ....................................*............................. + mul v9.8H, v11.8H, v0.H[0] // .....................................*............................ + add v6.8H, v15.8H, v5.8H // ......................................*........................... + add v26.8H, v28.8H, v4.8H // .......................................*.......................... 
+ sub v15.8H, v28.8H, v4.8H // ........................................*......................... + mls v9.8H, v16.8H, v7.H[0] // .........................................*........................ + add v3.8H, v6.8H, v26.8H // ..........................................*....................... + mul v8.8H, v15.8H, v0.H[4] // ...........................................*...................... + sqrdmulh v15.8H, v15.8H, v0.H[5] // ............................................*..................... + str q9, [x0, #256] // .............................................*.................... + sub v2.8H, v6.8H, v26.8H // ..............................................*................... + str q3, [x0, #64] // ...............................................*.................. + mls v8.8H, v15.8H, v7.H[0] // ................................................*................. + sqrdmulh v15.8H, v2.8H, v0.H[1] // .................................................*................ + mul v11.8H, v2.8H, v0.H[0] // ..................................................*............... + add v16.8H, v13.8H, v18.8H // ...................................................*.............. + sub v12.8H, v24.8H, v8.8H // ....................................................*............. + add v8.8H, v24.8H, v8.8H // .....................................................*............ + mls v11.8H, v15.8H, v7.H[0] // ......................................................*........... + sqrdmulh v26.8H, v12.8H, v0.H[1] // .......................................................*.......... + mul v12.8H, v12.8H, v0.H[0] // ........................................................*......... + str q8, [x0, #192] // .........................................................*........ + add v15.8H, v19.8H, v23.8H // ..........................................................*....... + str q11, [x0, #320] // ...........................................................*...... 
+ mls v12.8H, v26.8H, v7.H[0] // ............................................................*..... + str q15, [x0], #(16) // .............................................................*.... + str q16, [x0, #112] // ...............................................................*.. + str q12, [x0, #432] // .................................................................* + + // ----------------------- cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|--------------- + // sub v12.8H, v11.8H, v15.8H // ..*............................................................... + // add v26.8H, v11.8H, v15.8H // *................................................................. + // sub v8.8H, v24.8H, v16.8H // ........*......................................................... + // sqrdmulh v11.8H, v12.8H, v0.H[7] // .........*........................................................ + // mul v12.8H, v12.8H, v0.H[6] // ..............*................................................... + // sub v16.8H, v26.8H, v19.8H // ...*.............................................................. + // add v26.8H, v26.8H, v19.8H // ..............................*................................... + // sqrdmulh v15.8H, v8.8H, v1.H[1] // ............*..................................................... + // mul v8.8H, v8.8H, v1.H[0] // .............*.................................................... + // mls v12.8H, v11.8H, v7.H[0] // ....................*............................................. + // sub v11.8H, v9.8H, v6.8H // .........................*........................................ + // sqrdmulh v24.8H, v16.8H, v0.H[3] // ......*........................................................... + // mul v16.8H, v16.8H, v0.H[2] // .......*.......................................................... + // sub v9.8H, v26.8H, v23.8H // .................................*................................ 
+ // add v26.8H, v26.8H, v23.8H // ..........................................................*....... + // mls v8.8H, v15.8H, v7.H[0] // .................*................................................ + // sqrdmulh v15.8H, v11.8H, v1.H[3] // .............................*.................................... + // mul v11.8H, v11.8H, v1.H[2] // ...............................*.................................. + // sub v6.8H, v3.8H, v4.8H // ................*................................................. + // sub v3.8H, v12.8H, v8.8H // ........................*......................................... + // add v12.8H, v12.8H, v8.8H // ......................................*........................... + // mls v11.8H, v15.8H, v7.H[0] // ...................................*.............................. + // sqrdmulh v8.8H, v6.8H, v1.H[5] // .....................*............................................ + // mls v16.8H, v24.8H, v7.H[0] // ...........*...................................................... + // mul v15.8H, v6.8H, v1.H[4] // ......................*........................................... + // sqrdmulh v24.8H, v3.8H, v0.H[3] // ............................*..................................... + // mul v6.8H, v3.8H, v0.H[2] // ...........................*...................................... + // sqrdmulh v3.8H, v9.8H, v0.H[1] // ....................................*............................. + // mul v9.8H, v9.8H, v0.H[0] // .....................................*............................ + // str q26, [x0], #(16) // .............................................................*.... + // mls v15.8H, v8.8H, v7.H[0] // ..........................*....................................... + // mls v6.8H, v24.8H, v7.H[0] // ................................*................................. + // sub v26.8H, v28.8H, v13.8H // .*................................................................ 
+ // mls v9.8H, v3.8H, v7.H[0] // .........................................*........................ + // sub v8.8H, v11.8H, v15.8H // ........................................*......................... + // sqrdmulh v24.8H, v26.8H, v0.H[5] // .....*............................................................ + // mul v26.8H, v26.8H, v0.H[4] // ....*............................................................. + // add v11.8H, v11.8H, v15.8H // .......................................*.......................... + // sqrdmulh v15.8H, v8.8H, v0.H[5] // ............................................*..................... + // mul v8.8H, v8.8H, v0.H[4] // ...........................................*...................... + // mls v26.8H, v24.8H, v7.H[0] // ..........*....................................................... + // sub v24.8H, v12.8H, v11.8H // ..............................................*................... + // add v12.8H, v12.8H, v11.8H // ..........................................*....................... + // mls v8.8H, v15.8H, v7.H[0] // ................................................*................. + // sqrdmulh v11.8H, v24.8H, v0.H[1] // .................................................*................ + // mul v15.8H, v24.8H, v0.H[0] // ..................................................*............... + // sub v24.8H, v16.8H, v26.8H // ...............*.................................................. + // add v26.8H, v16.8H, v26.8H // ...................................................*.............. + // sub v16.8H, v6.8H, v8.8H // ....................................................*............. + // mls v15.8H, v11.8H, v7.H[0] // ......................................................*........... + // sqrdmulh v11.8H, v24.8H, v0.H[1] // ..................*............................................... + // mul v24.8H, v24.8H, v0.H[0] // ...................*.............................................. 
+ // add v8.8H, v6.8H, v8.8H // .....................................................*............ + // sqrdmulh v6.8H, v16.8H, v0.H[1] // .......................................................*.......... + // mul v16.8H, v16.8H, v0.H[0] // ........................................................*......... + // mls v24.8H, v11.8H, v7.H[0] // .......................*.......................................... + // str q9, [x0, #240] // .............................................*.................... + // mls v16.8H, v6.8H, v7.H[0] // ............................................................*..... + // str q15, [x0, #304] // ...........................................................*...... + // str q24, [x0, #368] // ..................................*............................... + // str q16, [x0, #432] // .................................................................* + // str q12, [x0, #48] // ...............................................*.................. + // str q26, [x0, #112] // ...............................................................*.. + // str q8, [x0, #176] // .........................................................*........ 
+ + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_clean.S new file mode 100644 index 0000000000..877a5f689f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/ntt_clean.S @@ -0,0 +1,283 @@ +/// +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// Copyright (c) 2024 The mlkem-native project authors +// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. 
+.macro mulmodq dst, src, const, idx0, idx1 + // Signed Barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. + sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] // t2 = rounding doubling high half of src times lane idx1 of const (plays the role of const_twisted in mulmod) + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] // dst = low 16 bits of src times lane idx0 of const + mls \dst\().8h, t2.8h, consts.h[0] // dst -= t2 times consts.h[0] (3329, loaded from c_consts) +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h // same three-step reduction as mulmodq, but with a full vector of per-lane constants + mul \dst\().8h, \src\().8h, \const\().8h // dst = low 16 bits of the lane-wise product + mls \dst\().8h, t2.8h, consts.h[0] // dst -= t2 times consts.h[0]; clobbers t2 +.endm + +.macro ct_butterfly a, b, root, idx0, idx1 + mulmodq tmp, \b, \root, \idx0, \idx1 // tmp = b times the selected root lane, reduced via mulmodq + sub \b\().8h, \a\().8h, tmp.8h // b = a - tmp (computed first, while a is still the old value) + add \a\().8h, \a\().8h, tmp.8h // a = a + tmp +.endm + +.macro ct_butterfly_v a, b, root, root_twisted + mulmod tmp, \b, \root, \root_twisted // vector-root variant of ct_butterfly: tmp = b times root, lane by lane + sub \b\().8h, \a\().8h, tmp.8h // b = a - tmp + add \a\().8h, \a\().8h, tmp.8h // a = a + tmp +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 // load 16 bytes into root0, then advance the pointer by 32 + ldr q_root1, [r01234_ptr, #-16] // load the second 16 bytes of the 32 just skipped into root1 +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 // load 16 bytes into root0 and advance the pointer by 16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) // six consecutive 16-byte vectors are consumed; the pointer is advanced first, so the loads below use negative offsets + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s // 4x4 transpose of 32-bit lanes across the four registers named data0..data3 (for argument data), using t0-t3 as scratch + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d // second stage: recombine 64-bit halves into the transposed rows + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro save_vregs + sub sp, sp, #(16*4) // reserve 64 bytes of stack + stp d8, d9, [sp, #16*0] // spill d8-d15 (low halves of v8-v15, callee-saved under AAPCS64) + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] // reload the registers spilled by save_vregs + ldp d10, d11, [sp, #16*1] + 
ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) // release the 64 bytes reserved by save_vregs +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + // Arguments + in .req x0 // Input/output buffer + r01234_ptr .req x1 // twiddles for layer 0,1,2,3,4 + r56_ptr .req x2 // twiddles for layer 5,6 + + inp .req x3 // working copy of in (set by mov inp, in at entry) + count .req x4 // loop counter + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + consts .req v7 // holds the c_consts vector; lane 0 is the modulus used by every mls reduction + q_consts .req q7 + + tmp .req v24 // product scratch used by the butterfly macros + t0 .req v25 + t1 .req v26 + t2 .req v27 // scratch clobbered by mulmodq and mulmod + t3 .req v28 + + .text + .global MLKEM_ASM_NAMESPACE(ntt_asm_clean) + +/* Literal pool */ +.p2align 4 +c_consts: + .short 3329 // q; read as consts.h[0] + .short 20159 // NOTE(review): looks like round(2^26 / 3329), a Barrett constant; not referenced in the code visible here - confirm + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + +MLKEM_ASM_NAMESPACE(ntt_asm_clean): + push_stack + ldr q_consts, c_consts // v7 = 3329, 20159, then six zero lanes + + mov inp, in + mov count, #4 // iteration count for the first loop + + load_roots_012 + + .p2align 2 + + // Bounds reasoning: + // - There are 7 layers + // - When passing from layer N to layer N+1, each layer-N value + // is modified through the addition/subtraction of a Montgomery + // product of a twiddle of absolute value < q/2 and a layer-N value. + // - Recalling that for C such that |a| < C * q and |t| < q/2 ... NOTE(review): the remainder of this bounds argument, and the scheduler header that preceded the cycle ruler below, were lost in extraction; restore them from the upstream mlkem-native ntt_clean.S + // 0 25 + // |------------------------|---- + ldr q21, [x0, #0] // *............................. + ldr q26, [x0, #64] // ..*........................... + ldr q29, [x0, #128] // ....*......................... + ldr q20, [x0, #192] // ......*....................... + ldr q23, [x0, #256] // ........*..................... 
+ ldr q11, [x0, #448] // ..........*................... + mul v2.8H, v23.8H, v0.H[0] // ............*................. + ldr q17, [x0, #320] // .............*................ + mul v15.8H, v11.8H, v0.H[0] // ...............*.............. + ldr q13, [x0, #384] // ................*............. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q21, [x0, #0] // *.............................. + // ldr q26, [x0, #64] // ..*............................ + // ldr q29, [x0, #128] // ....*.......................... + // ldr q20, [x0, #192] // ......*........................ + // ldr q23, [x0, #256] // ........*...................... + // ldr q17, [x0, #320] // .............*................. + // mul v2.8H, v23.8H, v0.H[0] // ............*.................. + // ldr q11, [x0, #448] // ..........*.................... + // ldr q13, [x0, #384] // ................*.............. + // mul v15.8H, v11.8H, v0.H[0] // ...............*............... + + sub count, count, #1 +1: + // Instructions: 76 + // Expected cycles: 84 + // Expected IPC: 0.90 + // + // Cycle bound: 84.0 + // IPC bound: 0.90 + // + // Wall time: 2.36s + // User time: 2.36s + // + // -------------------------------- cycle (expected) ---------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|-------- + sqrdmulh v14.8H, v23.8H, v0.H[1] // *................................................................................... + sqrdmulh v23.8H, v17.8H, v0.H[1] // .*.................................................................................. + mul v17.8H, v17.8H, v0.H[0] // ..*................................................................................. + sqrdmulh v28.8H, v13.8H, v0.H[1] // ...*................................................................................ + mls v2.8H, v14.8H, v7.H[0] // ....*............................................................................... 
+ mul v14.8H, v13.8H, v0.H[0] // .....*.............................................................................. + mls v17.8H, v23.8H, v7.H[0] // ......*............................................................................. + sqrdmulh v23.8H, v11.8H, v0.H[1] // .......*............................................................................ + sub v11.8H, v21.8H, v2.8H // ........*........................................................................... + mls v14.8H, v28.8H, v7.H[0] // .........*.......................................................................... + sub v28.8H, v26.8H, v17.8H // ..........*......................................................................... + add v17.8H, v26.8H, v17.8H // ...........*........................................................................ + add v2.8H, v21.8H, v2.8H // ............*....................................................................... + sub v13.8H, v29.8H, v14.8H // .............*...................................................................... + add v14.8H, v29.8H, v14.8H // ..............*..................................................................... + mls v15.8H, v23.8H, v7.H[0] // ...............*.................................................................... + sqrdmulh v23.8H, v13.8H, v0.H[5] // ................*................................................................... + mul v13.8H, v13.8H, v0.H[4] // .................*.................................................................. + sqrdmulh v21.8H, v14.8H, v0.H[3] // ..................*................................................................. + sub v26.8H, v20.8H, v15.8H // ...................*................................................................ + add v15.8H, v20.8H, v15.8H // ....................*............................................................... 
+ mls v13.8H, v23.8H, v7.H[0] // .....................*.............................................................. + sqrdmulh v23.8H, v26.8H, v0.H[5] // ......................*............................................................. + mul v26.8H, v26.8H, v0.H[4] // .......................*............................................................ + mul v14.8H, v14.8H, v0.H[2] // ........................*........................................................... + sub v29.8H, v11.8H, v13.8H // .........................*.......................................................... + add v11.8H, v11.8H, v13.8H // ..........................*......................................................... + mls v26.8H, v23.8H, v7.H[0] // ...........................*........................................................ + sqrdmulh v23.8H, v15.8H, v0.H[3] // ............................*....................................................... + mul v13.8H, v15.8H, v0.H[2] // .............................*...................................................... + mls v14.8H, v21.8H, v7.H[0] // ..............................*..................................................... + sub v15.8H, v28.8H, v26.8H // ...............................*.................................................... + add v28.8H, v28.8H, v26.8H // ................................*................................................... + mls v13.8H, v23.8H, v7.H[0] // .................................*.................................................. + sub v23.8H, v2.8H, v14.8H // ..................................*................................................. + add v14.8H, v2.8H, v14.8H // ...................................*................................................ + sqrdmulh v2.8H, v28.8H, v1.H[3] // ....................................*............................................... 
+ sub v21.8H, v17.8H, v13.8H // .....................................*.............................................. + add v17.8H, v17.8H, v13.8H // ......................................*............................................. + mul v28.8H, v28.8H, v1.H[2] // .......................................*............................................ + sqrdmulh v13.8H, v21.8H, v1.H[1] // ........................................*........................................... + sqrdmulh v26.8H, v17.8H, v0.H[7] // .........................................*.......................................... + mul v17.8H, v17.8H, v0.H[6] // ..........................................*......................................... + mul v21.8H, v21.8H, v1.H[0] // ...........................................*........................................ + mls v28.8H, v2.8H, v7.H[0] // ............................................*....................................... + sqrdmulh v2.8H, v15.8H, v1.H[5] // .............................................*...................................... + mls v17.8H, v26.8H, v7.H[0] // ..............................................*..................................... + mls v21.8H, v13.8H, v7.H[0] // ...............................................*.................................... + sub v13.8H, v11.8H, v28.8H // ................................................*................................... + add v28.8H, v11.8H, v28.8H // .................................................*.................................. + sub v11.8H, v14.8H, v17.8H // ..................................................*................................. + mul v15.8H, v15.8H, v1.H[4] // ...................................................*................................ + add v14.8H, v14.8H, v17.8H // ....................................................*............................... 
+ sub v17.8H, v23.8H, v21.8H // .....................................................*.............................. + add v23.8H, v23.8H, v21.8H // ......................................................*............................. + mls v15.8H, v2.8H, v7.H[0] // .......................................................*............................ + str q14, [x0], #(16) // ........................................................*........................... + ldr q21, [x0, #0] // .........................................................e.......................... + sub v14.8H, v29.8H, v15.8H // ...........................................................*........................ + add v2.8H, v29.8H, v15.8H // ............................................................*....................... + str q11, [x0, #48] // .............................................................*...................... + ldr q26, [x0, #64] // ..............................................................e..................... + str q23, [x0, #112] // ................................................................*................... + ldr q29, [x0, #128] // .................................................................e.................. + str q17, [x0, #176] // ...................................................................*................ + ldr q20, [x0, #192] // ....................................................................e............... + str q28, [x0, #240] // ......................................................................*............. + ldr q23, [x0, #256] // .......................................................................e............ + str q13, [x0, #304] // .........................................................................*.......... + ldr q17, [x0, #320] // ..........................................................................e......... 
+ str q2, [x0, #368] // ............................................................................*....... + mul v2.8H, v23.8H, v0.H[0] // .............................................................................e...... + str q14, [x0, #432] // ..............................................................................*..... + ldr q11, [x0, #448] // ...............................................................................e.... + ldr q13, [x0, #384] // .................................................................................e.. + mul v15.8H, v11.8H, v0.H[0] // ...................................................................................e + + // ------------------------------------------- cycle (expected) --------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|----- + // ldr q8, [x0, #0] // e..........................'........................................................~..................... + // ldr q9, [x0, #(1*(512/8))] // .....e.....................'.............................................................~................ + // ldr q10, [x0, #(2*(512/8))] // ........e..................'................................................................~............. + // ldr q11, [x0, #(3*(512/8))] // ...........e...............'...................................................................~.......... + // ldr q12, [x0, #(4*(512/8))] // ..............e............'......................................................................~....... + // ldr q13, [x0, #(5*(512/8))] // .................e.........'.........................................................................~.... + // ldr q14, [x0, #(6*(512/8))] // ........................e..'.............................................................................. 
+ // ldr q15, [x0, #(7*(512/8))] // ......................e....'.............................................................................. + // sqrdmulh v27.8h, v12.8h, v0.h[1] // ...........................*.............................................................................. + // mul v24.8h, v12.8h, v0.h[0] // ....................e......'............................................................................~. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'...*.......................................................................... + // sub v12.8h, v8.8h, v24.8h // ...........................'.......*...................................................................... + // add v8.8h, v8.8h, v24.8h // ...........................'...........*.................................................................. + // sqrdmulh v27.8h, v13.8h, v0.h[1] // ...........................'*............................................................................. + // mul v24.8h, v13.8h, v0.h[0] // ...........................'.*............................................................................ + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.....*........................................................................ + // sub v13.8h, v9.8h, v24.8h // ...........................'.........*.................................................................... + // add v9.8h, v9.8h, v24.8h // ...........................'..........*................................................................... + // sqrdmulh v27.8h, v14.8h, v0.h[1] // ...........................'..*........................................................................... + // mul v24.8h, v14.8h, v0.h[0] // ...........................'....*......................................................................... 
+ // mls v24.8h, v27.8h, v7.h[0] // ...........................'........*..................................................................... + // sub v14.8h, v10.8h, v24.8h // ...........................'............*................................................................. + // add v10.8h, v10.8h, v24.8h // ...........................'.............*................................................................ + // sqrdmulh v27.8h, v15.8h, v0.h[1] // ...........................'......*....................................................................... + // mul v24.8h, v15.8h, v0.h[0] // ..........................e'.............................................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..............*............................................................... + // sub v15.8h, v11.8h, v24.8h // ...........................'..................*........................................................... + // add v11.8h, v11.8h, v24.8h // ...........................'...................*.......................................................... + // sqrdmulh v27.8h, v10.8h, v0.h[3] // ...........................'.................*............................................................ + // mul v24.8h, v10.8h, v0.h[2] // ...........................'.......................*...................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.............................*................................................ + // sub v10.8h, v8.8h, v24.8h // ...........................'.................................*............................................ + // add v8.8h, v8.8h, v24.8h // ...........................'..................................*........................................... 
+ // sqrdmulh v27.8h, v11.8h, v0.h[3] // ...........................'...........................*.................................................. + // mul v24.8h, v11.8h, v0.h[2] // ...........................'............................*................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'................................*............................................. + // sub v11.8h, v9.8h, v24.8h // ...........................'....................................*......................................... + // add v9.8h, v9.8h, v24.8h // ...........................'.....................................*........................................ + // sqrdmulh v27.8h, v14.8h, v0.h[5] // ...........................'...............*.............................................................. + // mul v24.8h, v14.8h, v0.h[4] // ...........................'................*............................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'....................*......................................................... + // sub v14.8h, v12.8h, v24.8h // ...........................'........................*..................................................... + // add v12.8h, v12.8h, v24.8h // ...........................'.........................*.................................................... + // sqrdmulh v27.8h, v15.8h, v0.h[5] // ...........................'.....................*........................................................ + // mul v24.8h, v15.8h, v0.h[4] // ...........................'......................*....................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..........................*................................................... 
+ // sub v15.8h, v13.8h, v24.8h // ...........................'..............................*............................................... + // add v13.8h, v13.8h, v24.8h // ...........................'...............................*.............................................. + // sqrdmulh v27.8h, v9.8h, v0.h[7] // ...........................'........................................*..................................... + // mul v24.8h, v9.8h, v0.h[6] // ...........................'.........................................*.................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.............................................*................................ + // sub v9.8h, v8.8h, v24.8h // ...........................'.................................................*............................ + // add v8.8h, v8.8h, v24.8h // ...........................'...................................................*.......................... + // sqrdmulh v27.8h, v11.8h, v1.h[1] // ...........................'.......................................*...................................... + // mul v24.8h, v11.8h, v1.h[0] // ...........................'..........................................*................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..............................................*............................... + // sub v11.8h, v10.8h, v24.8h // ...........................'....................................................*......................... + // add v10.8h, v10.8h, v24.8h // ...........................'.....................................................*........................ + // sqrdmulh v27.8h, v13.8h, v1.h[3] // ...........................'...................................*.......................................... 
+ // mul v24.8h, v13.8h, v1.h[2] // ...........................'......................................*....................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'...........................................*.................................. + // sub v13.8h, v12.8h, v24.8h // ...........................'...............................................*.............................. + // add v12.8h, v12.8h, v24.8h // ...........................'................................................*............................. + // sqrdmulh v27.8h, v15.8h, v1.h[5] // ...........................'............................................*................................. + // mul v24.8h, v15.8h, v1.h[4] // ...........................'..................................................*........................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'......................................................*....................... + // sub v15.8h, v14.8h, v24.8h // ..~........................'..........................................................*................... + // add v14.8h, v14.8h, v24.8h // ...~.......................'...........................................................*.................. + // str q8, [x0], #(16) // ...........................'.......................................................*...................... + // str q9, [x0, #(-16 + 1*(512/8))] // ....~......................'............................................................*................. + // str q10, [x0, #(-16 + 2*(512/8))] // .......~...................'...............................................................*.............. + // str q11, [x0, #(-16 + 3*(512/8))] // ..........~................'..................................................................*........... 
+ // str q12, [x0, #(-16 + 4*(512/8))] // .............~.............'.....................................................................*........ + // str q13, [x0, #(-16 + 5*(512/8))] // ................~..........'........................................................................*..... + // str q14, [x0, #(-16 + 6*(512/8))] // ...................~.......'...........................................................................*.. + // str q15, [x0, #(-16 + 7*(512/8))] // .....................~.....'.............................................................................* + + sub count, count, 1 + cbnz count, 1b + // Instructions: 66 + // Expected cycles: 67 + // Expected IPC: 0.99 + // + // Cycle bound: 67.0 + // IPC bound: 0.99 + // + // Wall time: 7.51s + // User time: 7.51s + // + // ------------------------ cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|---------------- + sqrdmulh v27.8H, v11.8H, v0.H[1] // *.................................................................. + mul v8.8H, v13.8H, v0.H[0] // .*................................................................. + sqrdmulh v22.8H, v13.8H, v0.H[1] // ..*................................................................ + mul v11.8H, v17.8H, v0.H[0] // ...*............................................................... + mls v15.8H, v27.8H, v7.H[0] // ....*.............................................................. + sqrdmulh v28.8H, v17.8H, v0.H[1] // .....*............................................................. + mls v8.8H, v22.8H, v7.H[0] // ......*............................................................ + sqrdmulh v5.8H, v23.8H, v0.H[1] // .......*........................................................... + add v16.8H, v20.8H, v15.8H // ........*.......................................................... + mls v11.8H, v28.8H, v7.H[0] // .........*......................................................... 
+ sub v6.8H, v29.8H, v8.8H // ..........*........................................................ + sqrdmulh v17.8H, v16.8H, v0.H[3] // ...........*....................................................... + mul v23.8H, v16.8H, v0.H[2] // ............*...................................................... + mul v13.8H, v6.8H, v0.H[4] // .............*..................................................... + sqrdmulh v28.8H, v6.8H, v0.H[5] // ..............*.................................................... + mls v2.8H, v5.8H, v7.H[0] // ...............*................................................... + mls v23.8H, v17.8H, v7.H[0] // ................*.................................................. + add v27.8H, v26.8H, v11.8H // .................*................................................. + mls v13.8H, v28.8H, v7.H[0] // ..................*................................................ + sub v9.8H, v21.8H, v2.8H // ...................*............................................... + add v18.8H, v29.8H, v8.8H // ....................*.............................................. + sub v14.8H, v27.8H, v23.8H // .....................*............................................. + add v29.8H, v9.8H, v13.8H // ......................*............................................ + sub v30.8H, v9.8H, v13.8H // .......................*........................................... + mul v28.8H, v14.8H, v1.H[0] // ........................*.......................................... + sqrdmulh v9.8H, v18.8H, v0.H[3] // .........................*......................................... + mul v22.8H, v18.8H, v0.H[2] // ..........................*........................................ + sqrdmulh v17.8H, v14.8H, v1.H[1] // ...........................*....................................... + sub v14.8H, v20.8H, v15.8H // ............................*...................................... 
+ add v24.8H, v21.8H, v2.8H // .............................*..................................... + mls v22.8H, v9.8H, v7.H[0] // ..............................*.................................... + sqrdmulh v9.8H, v14.8H, v0.H[5] // ...............................*................................... + mul v13.8H, v14.8H, v0.H[4] // ................................*.................................. + mls v28.8H, v17.8H, v7.H[0] // .................................*................................. + sub v5.8H, v24.8H, v22.8H // ..................................*................................ + sub v2.8H, v26.8H, v11.8H // ...................................*............................... + mls v13.8H, v9.8H, v7.H[0] // ....................................*.............................. + sub v17.8H, v5.8H, v28.8H // .....................................*............................. + add v14.8H, v5.8H, v28.8H // ......................................*............................ + add v28.8H, v27.8H, v23.8H // .......................................*........................... + str q17, [x0, #192] // ........................................*.......................... + add v17.8H, v2.8H, v13.8H // .........................................*......................... + str q14, [x0, #128] // ..........................................*........................ + sub v13.8H, v2.8H, v13.8H // ...........................................*....................... + sqrdmulh v26.8H, v17.8H, v1.H[3] // ............................................*...................... + mul v15.8H, v17.8H, v1.H[2] // .............................................*..................... + add v5.8H, v24.8H, v22.8H // ..............................................*.................... + sqrdmulh v23.8H, v13.8H, v1.H[5] // ...............................................*................... + mul v13.8H, v13.8H, v1.H[4] // ................................................*.................. 
+ mls v15.8H, v26.8H, v7.H[0] // .................................................*................. + sqrdmulh v14.8H, v28.8H, v0.H[7] // ..................................................*................ + mul v17.8H, v28.8H, v0.H[6] // ...................................................*............... + mls v13.8H, v23.8H, v7.H[0] // ....................................................*.............. + add v6.8H, v29.8H, v15.8H // .....................................................*............. + sub v28.8H, v29.8H, v15.8H // ......................................................*............ + mls v17.8H, v14.8H, v7.H[0] // .......................................................*........... + str q6, [x0, #256] // ........................................................*.......... + add v14.8H, v30.8H, v13.8H // .........................................................*......... + str q28, [x0, #320] // ..........................................................*........ + sub v23.8H, v30.8H, v13.8H // ...........................................................*....... + str q14, [x0, #384] // ............................................................*...... + add v3.8H, v5.8H, v17.8H // .............................................................*..... + str q23, [x0, #448] // ..............................................................*.... + sub v28.8H, v5.8H, v17.8H // ...............................................................*... + str q3, [x0], #(16) // ................................................................*.. + str q28, [x0, #48] // ..................................................................* + + // ------------------------ cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|---------------- + // sqrdmulh v14.8H, v23.8H, v0.H[1] // .......*........................................................... 
+ // sqrdmulh v23.8H, v17.8H, v0.H[1] // .....*............................................................. + // mul v17.8H, v17.8H, v0.H[0] // ...*............................................................... + // sqrdmulh v28.8H, v13.8H, v0.H[1] // ..*................................................................ + // mls v2.8H, v14.8H, v7.H[0] // ...............*................................................... + // mul v14.8H, v13.8H, v0.H[0] // .*................................................................. + // mls v17.8H, v23.8H, v7.H[0] // .........*......................................................... + // sqrdmulh v23.8H, v11.8H, v0.H[1] // *.................................................................. + // sub v11.8H, v21.8H, v2.8H // ...................*............................................... + // mls v14.8H, v28.8H, v7.H[0] // ......*............................................................ + // sub v28.8H, v26.8H, v17.8H // ...................................*............................... + // add v17.8H, v26.8H, v17.8H // .................*................................................. + // add v2.8H, v21.8H, v2.8H // .............................*..................................... + // sub v13.8H, v29.8H, v14.8H // ..........*........................................................ + // add v14.8H, v29.8H, v14.8H // ....................*.............................................. + // mls v15.8H, v23.8H, v7.H[0] // ....*.............................................................. + // sqrdmulh v23.8H, v13.8H, v0.H[5] // ..............*.................................................... + // mul v13.8H, v13.8H, v0.H[4] // .............*..................................................... + // sqrdmulh v21.8H, v14.8H, v0.H[3] // .........................*......................................... + // sub v26.8H, v20.8H, v15.8H // ............................*...................................... 
+ // add v15.8H, v20.8H, v15.8H // ........*.......................................................... + // mls v13.8H, v23.8H, v7.H[0] // ..................*................................................ + // sqrdmulh v23.8H, v26.8H, v0.H[5] // ...............................*................................... + // mul v26.8H, v26.8H, v0.H[4] // ................................*.................................. + // mul v14.8H, v14.8H, v0.H[2] // ..........................*........................................ + // sub v29.8H, v11.8H, v13.8H // .......................*........................................... + // add v11.8H, v11.8H, v13.8H // ......................*............................................ + // mls v26.8H, v23.8H, v7.H[0] // ....................................*.............................. + // sqrdmulh v23.8H, v15.8H, v0.H[3] // ...........*....................................................... + // mul v13.8H, v15.8H, v0.H[2] // ............*...................................................... + // mls v14.8H, v21.8H, v7.H[0] // ..............................*.................................... + // sub v15.8H, v28.8H, v26.8H // ...........................................*....................... + // add v28.8H, v28.8H, v26.8H // .........................................*......................... + // mls v13.8H, v23.8H, v7.H[0] // ................*.................................................. + // sub v23.8H, v2.8H, v14.8H // ..................................*................................ + // add v14.8H, v2.8H, v14.8H // ..............................................*.................... + // sqrdmulh v2.8H, v28.8H, v1.H[3] // ............................................*...................... + // sub v21.8H, v17.8H, v13.8H // .....................*............................................. + // add v17.8H, v17.8H, v13.8H // .......................................*........................... 
+ // mul v28.8H, v28.8H, v1.H[2] // .............................................*..................... + // sqrdmulh v13.8H, v21.8H, v1.H[1] // ...........................*....................................... + // sqrdmulh v26.8H, v17.8H, v0.H[7] // ..................................................*................ + // mul v17.8H, v17.8H, v0.H[6] // ...................................................*............... + // mul v21.8H, v21.8H, v1.H[0] // ........................*.......................................... + // mls v28.8H, v2.8H, v7.H[0] // .................................................*................. + // sqrdmulh v2.8H, v15.8H, v1.H[5] // ...............................................*................... + // mls v17.8H, v26.8H, v7.H[0] // .......................................................*........... + // mls v21.8H, v13.8H, v7.H[0] // .................................*................................. + // sub v13.8H, v11.8H, v28.8H // ......................................................*............ + // add v28.8H, v11.8H, v28.8H // .....................................................*............. + // sub v11.8H, v14.8H, v17.8H // ...............................................................*... + // mul v15.8H, v15.8H, v1.H[4] // ................................................*.................. + // add v14.8H, v14.8H, v17.8H // .............................................................*..... + // sub v17.8H, v23.8H, v21.8H // .....................................*............................. + // add v23.8H, v23.8H, v21.8H // ......................................*............................ + // mls v15.8H, v2.8H, v7.H[0] // ....................................................*.............. + // str q14, [x0], #(16) // ................................................................*.. + // sub v14.8H, v29.8H, v15.8H // ...........................................................*....... 
+ // add v2.8H, v29.8H, v15.8H // .........................................................*......... + // str q11, [x0, #48] // ..................................................................* + // str q23, [x0, #112] // ..........................................*........................ + // str q17, [x0, #176] // ........................................*.......................... + // str q28, [x0, #240] // ........................................................*.......... + // str q13, [x0, #304] // ..........................................................*........ + // str q2, [x0, #368] // ............................................................*...... + // str q14, [x0, #432] // ..............................................................*.... + + + mov in, inp + mov count, #8 + + .p2align 2 + // Instructions: 24 + // Expected cycles: 31 + // Expected IPC: 0.77 + // + // Cycle bound: 31.0 + // IPC bound: 0.77 + // + // Wall time: 0.08s + // User time: 0.08s + // + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + ldr q2, [x1], #16 // *.............................. + ldr q14, [x0, #48] // ..*............................ + ldr q1, [x0, #32] // ....*.......................... + mul v17.8H, v14.8H, v2.H[0] // ......*........................ + sqrdmulh v14.8H, v14.8H, v2.H[1] // .......*....................... + mul v8.8H, v1.8H, v2.H[0] // ........*...................... + ldr q23, [x0, #16] // .........*..................... + mls v17.8H, v14.8H, v7.H[0] // ...........*................... + sqrdmulh v1.8H, v1.8H, v2.H[1] // ............*.................. + ldr q30, [x2], #(6*16) // .............*................. + sub v14.8H, v23.8H, v17.8H // ...............*............... + add v10.8H, v23.8H, v17.8H // ................*.............. + mls v8.8H, v1.8H, v7.H[0] // .................*............. + sqrdmulh v1.8H, v14.8H, v2.H[5] // ..................*............ 
+ mul v14.8H, v14.8H, v2.H[4] // ...................*........... + ldr q27, [x0, #0] // ....................*.......... + mul v23.8H, v10.8H, v2.H[2] // ......................*........ + mls v14.8H, v1.8H, v7.H[0] // .......................*....... + sub v1.8H, v27.8H, v8.8H // ........................*...... + ldr q28, [x2, #-64] // .........................*..... + add v12.8H, v1.8H, v14.8H // ...........................*... + sqrdmulh v21.8H, v10.8H, v2.H[3] // ............................*.. + sub v5.8H, v1.8H, v14.8H // .............................*. + ldr q13, [x2, #-16] // ..............................* + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q19, [x0, #48] // ..*............................ + // ldr q1, [x1], #16 // *.............................. + // mul v4.8H, v19.8H, v1.H[0] // ......*........................ + // sqrdmulh v19.8H, v19.8H, v1.H[1] // .......*....................... + // ldr q25, [x0, #16] // .........*..................... + // mls v4.8H, v19.8H, v7.H[0] // ...........*................... + // sub v24.8H, v25.8H, v4.8H // ...............*............... + // add v4.8H, v25.8H, v4.8H // ................*.............. + // sqrdmulh v23.8H, v24.8H, v1.H[5] // ..................*............ + // mul v20.8H, v24.8H, v1.H[4] // ...................*........... + // sqrdmulh v21.8H, v4.8H, v1.H[3] // ............................*.. + // mls v20.8H, v23.8H, v7.H[0] // .......................*....... + // mul v23.8H, v4.8H, v1.H[2] // ......................*........ + // ldr q31, [x0, #32] // ....*.......................... + // mul v8.8H, v31.8H, v1.H[0] // ........*...................... + // sqrdmulh v1.8H, v31.8H, v1.H[1] // ............*.................. + // mls v8.8H, v1.8H, v7.H[0] // .................*............. + // ldr q27, [x0, #0] // ....................*.......... + // sub v10.8H, v27.8H, v8.8H // ........................*...... 
+ // add v12.8H, v10.8H, v20.8H // ...........................*... + // ldr q30, [x2], #(6*16) // .............*................. + // ldr q28, [x2, #-64] // .........................*..... + // sub v5.8H, v10.8H, v20.8H // .............................*. + // ldr q13, [x2, #-16] // ..............................* + + sub count, count, #1 +1: + // Instructions: 71 + // Expected cycles: 82 + // Expected IPC: 0.87 + // + // Cycle bound: 82.0 + // IPC bound: 0.87 + // + // Wall time: 11.93s + // User time: 11.93s + // + // ------------------------------- cycle (expected) --------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------ + ldr q19, [x0, #112] // e................................................................................. + ldr q1, [x1], #16 // ..e............................................................................... + mls v23.8H, v21.8H, v7.H[0] // ....*............................................................................. + add v6.8H, v27.8H, v8.8H // .....*............................................................................ + mul v4.8H, v19.8H, v1.H[0] // ......e........................................................................... + sqrdmulh v19.8H, v19.8H, v1.H[1] // .......e.......................................................................... + ldr q25, [x0, #80] // ........e......................................................................... + trn1 v11.4S, v12.4S, v5.4S // ..........*....................................................................... + mls v4.8H, v19.8H, v7.H[0] // ...........e...................................................................... + sub v0.8H, v6.8H, v23.8H // ............*..................................................................... + ldr q16, [x2, #-80] // .............*.................................................................... 
+ sub v24.8H, v25.8H, v4.8H // ...............e.................................................................. + add v26.8H, v6.8H, v23.8H // ................*................................................................. + add v4.8H, v25.8H, v4.8H // .................e................................................................ + sqrdmulh v23.8H, v24.8H, v1.H[5] // ..................e............................................................... + mul v20.8H, v24.8H, v1.H[4] // ...................e.............................................................. + sqrdmulh v21.8H, v4.8H, v1.H[3] // ....................e............................................................. + trn1 v27.4S, v26.4S, v0.4S // .....................*............................................................ + trn2 v25.4S, v12.4S, v5.4S // ......................*........................................................... + mls v20.8H, v23.8H, v7.H[0] // .......................e.......................................................... + mul v23.8H, v4.8H, v1.H[2] // ........................e......................................................... + ldr q31, [x0, #96] // .........................e........................................................ + trn2 v12.4S, v26.4S, v0.4S // ...........................*...................................................... + trn2 v19.2D, v27.2D, v11.2D // ............................*..................................................... + mul v8.8H, v31.8H, v1.H[0] // .............................e.................................................... + sqrdmulh v1.8H, v31.8H, v1.H[1] // ..............................e................................................... + trn2 v10.2D, v12.2D, v25.2D // ...............................*.................................................. + sqrdmulh v0.8H, v19.8H, v16.8H // ................................*................................................. 
+ sqrdmulh v18.8H, v10.8H, v16.8H // .................................*................................................ + trn1 v16.2D, v27.2D, v11.2D // ..................................*............................................... + trn1 v2.2D, v12.2D, v25.2D // ...................................*.............................................. + mul v12.8H, v10.8H, v30.8H // ....................................*............................................. + mul v10.8H, v19.8H, v30.8H // .....................................*............................................ + mls v8.8H, v1.8H, v7.H[0] // ......................................e........................................... + ldr q14, [x2, #-48] // .......................................*.......................................... + mls v10.8H, v0.8H, v7.H[0] // .........................................*........................................ + mls v12.8H, v18.8H, v7.H[0] // ..........................................*....................................... + ldr q27, [x0, #64] // ...........................................e...................................... + add v9.8H, v16.8H, v10.8H // .............................................*.................................... + sub v16.8H, v16.8H, v10.8H // ..............................................*................................... + sub v25.8H, v2.8H, v12.8H // ...............................................*.................................. + add v30.8H, v2.8H, v12.8H // ................................................*................................. + sub v10.8H, v27.8H, v8.8H // .................................................e................................ + sqrdmulh v22.8H, v25.8H, v13.8H // ..................................................*............................... + sqrdmulh v13.8H, v30.8H, v14.8H // ...................................................*.............................. 
+ ldr q14, [x2, #-32] // ....................................................*............................. + add v12.8H, v10.8H, v20.8H // ......................................................e........................... + mul v5.8H, v30.8H, v28.8H // .......................................................*.......................... + mul v26.8H, v25.8H, v14.8H // ........................................................*......................... + ldr q30, [x2], #(6*16) // .........................................................e........................ + mls v5.8H, v13.8H, v7.H[0] // ...........................................................*...................... + mls v26.8H, v22.8H, v7.H[0] // ............................................................*..................... + ldr q28, [x2, #-64] // .............................................................e.................... + add v13.8H, v9.8H, v5.8H // ...............................................................*.................. + sub v9.8H, v9.8H, v5.8H // ................................................................*................. + sub v5.8H, v16.8H, v26.8H // .................................................................*................ + add v25.8H, v16.8H, v26.8H // ..................................................................*............... + trn1 v15.4S, v13.4S, v9.4S // ...................................................................*.............. + trn2 v3.4S, v13.4S, v9.4S // ....................................................................*............. + trn1 v13.4S, v25.4S, v5.4S // .....................................................................*............ + trn2 v31.4S, v25.4S, v5.4S // ......................................................................*........... + sub v5.8H, v10.8H, v20.8H // .......................................................................e.......... 
+ trn1 v2.2D, v15.2D, v13.2D // ........................................................................*......... + trn2 v9.2D, v15.2D, v13.2D // .........................................................................*........ + str q2, [x0], #(16*4) // ..........................................................................*....... + trn1 v29.2D, v3.2D, v31.2D // ...........................................................................*...... + str q9, [x0, #-32] // ............................................................................*..... + trn2 v9.2D, v3.2D, v31.2D // .............................................................................*.... + str q29, [x0, #-48] // ..............................................................................*... + ldr q13, [x2, #-16] // ...............................................................................e.. + str q9, [x0, #-16] // .................................................................................* + + // ------------------------------------------------------------------------ cycle (expected) -------------------------------------------------------------------------> + // 0 25 50 75 100 125 150 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------- + // ldr q8, [x0, #(16*0)] // ...........................................e......................................'..........................................~...................................... + // ldr q9, [x0, #(16*1)] // ........e.........................................................................'.......~......................................................................... + // ldr q10, [x0, #(16*2)] // .........................e........................................................'........................~........................................................ 
+ // ldr q11, [x0, #(16*3)] // e.................................................................................~................................................................................. + // ldr q0, [x1], #16 // ..e...............................................................................'.~............................................................................... + // sqrdmulh v27.8h, v10.8h, v0.h[1] // ..............................e...................................................'.............................~................................................... + // mul v24.8h, v10.8h, v0.h[0] // .............................e....................................................'............................~.................................................... + // mls v24.8h, v27.8h, v7.h[0] // ......................................e...........................................'.....................................~........................................... + // sub v10.8h, v8.8h, v24.8h // .................................................e................................'................................................~................................ + // add v8.8h, v8.8h, v24.8h // .....~............................................................................'....*............................................................................ + // sqrdmulh v27.8h, v11.8h, v0.h[1] // .......e..........................................................................'......~.......................................................................... + // mul v24.8h, v11.8h, v0.h[0] // ......e...........................................................................'.....~........................................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........e......................................................................'..........~...................................................................... 
+ // sub v11.8h, v9.8h, v24.8h // ...............e..................................................................'..............~.................................................................. + // add v9.8h, v9.8h, v24.8h // .................e................................................................'................~................................................................ + // sqrdmulh v27.8h, v9.8h, v0.h[3] // ....................e.............................................................'...................~............................................................. + // mul v24.8h, v9.8h, v0.h[2] // ........................e.........................................................'.......................~......................................................... + // mls v24.8h, v27.8h, v7.h[0] // ....~.............................................................................'...*............................................................................. + // sub v9.8h, v8.8h, v24.8h // ............~.....................................................................'...........*..................................................................... + // add v8.8h, v8.8h, v24.8h // ................~.................................................................'...............*................................................................. + // sqrdmulh v27.8h, v11.8h, v0.h[5] // ..................e...............................................................'.................~............................................................... + // mul v24.8h, v11.8h, v0.h[4] // ...................e..............................................................'..................~.............................................................. 
+ // mls v24.8h, v27.8h, v7.h[0] // .......................e..........................................................'......................~.......................................................... + // sub v11.8h, v10.8h, v24.8h // .......................................................................e..........'......................................................................~.......... + // add v10.8h, v10.8h, v24.8h // ......................................................e...........................'.....................................................~........................... + // trn1 v25.4s, v8.4s, v9.4s // .....................~............................................................'....................*............................................................ + // trn2 v26.4s, v8.4s, v9.4s // ...........................~......................................................'..........................*...................................................... + // trn1 v27.4s, v10.4s, v11.4s // ..........~.......................................................................'.........*....................................................................... + // trn2 v28.4s, v10.4s, v11.4s // ......................~...........................................................'.....................*........................................................... + // trn2 v10.2d, v25.2d, v27.2d // ............................~.....................................................'...........................*..................................................... + // trn2 v11.2d, v26.2d, v28.2d // ...............................~..................................................'..............................*.................................................. 
+ // trn1 v8.2d, v25.2d, v27.2d // ..................................~...............................................'.................................*............................................... + // trn1 v9.2d, v26.2d, v28.2d // ...................................~..............................................'..................................*.............................................. + // ldr q0, [x2], #(6*16) // .........................................................e........................'........................................................~........................ + // ldr q4, [x2, #(-6*16 + 1*16)] // .............~....................................................................'............*.................................................................... + // ldr q1, [x2, #(-6*16 + 2*16)] // .............................................................e....................'............................................................~.................... + // ldr q5, [x2, #(-6*16 + 3*16)] // .......................................~..........................................'......................................*.......................................... + // ldr q2, [x2, #(-6*16 + 4*16)] // ....................................................~.............................'...................................................*............................. + // ldr q6, [x2, #(-6*16 + 5*16)] // ...............................................................................e..'..............................................................................~.. + // sqrdmulh v27.8h, v10.8h, v4.8h // ................................~.................................................'...............................*................................................. 
+ // mul v24.8h, v10.8h, v0.8h // .....................................~............................................'....................................*............................................ + // mls v24.8h, v27.8h, v7.h[0] // .........................................~........................................'........................................*........................................ + // sub v10.8h, v8.8h, v24.8h // ..............................................~...................................'.............................................*................................... + // add v8.8h, v8.8h, v24.8h // .............................................~....................................'............................................*.................................... + // sqrdmulh v27.8h, v11.8h, v4.8h // .................................~................................................'................................*................................................ + // mul v24.8h, v11.8h, v0.8h // ....................................~.............................................'...................................*............................................. + // mls v24.8h, v27.8h, v7.h[0] // ..........................................~.......................................'.........................................*....................................... + // sub v11.8h, v9.8h, v24.8h // ...............................................~..................................'..............................................*.................................. + // add v9.8h, v9.8h, v24.8h // ................................................~.................................'...............................................*................................. 
+ // sqrdmulh v27.8h, v9.8h, v5.8h // ...................................................~..............................'..................................................*.............................. + // mul v24.8h, v9.8h, v1.8h // .......................................................~..........................'......................................................*.......................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................................................~......................'..........................................................*...................... + // sub v9.8h, v8.8h, v24.8h // ................................................................~.................'...............................................................*................. + // add v8.8h, v8.8h, v24.8h // ...............................................................~..................'..............................................................*.................. + // sqrdmulh v27.8h, v11.8h, v6.8h // ..................................................~...............................'.................................................*............................... + // mul v24.8h, v11.8h, v2.8h // ........................................................~.........................'.......................................................*......................... + // mls v24.8h, v27.8h, v7.h[0] // ............................................................~.....................'...........................................................*..................... + // sub v11.8h, v10.8h, v24.8h // .................................................................~................'................................................................*................ 
+ // add v10.8h, v10.8h, v24.8h // ..................................................................~...............'.................................................................*............... + // trn1 v25.4s, v8.4s, v9.4s // ...................................................................~..............'..................................................................*.............. + // trn2 v26.4s, v8.4s, v9.4s // ....................................................................~.............'...................................................................*............. + // trn1 v27.4s, v10.4s, v11.4s // .....................................................................~............'....................................................................*............ + // trn2 v28.4s, v10.4s, v11.4s // ......................................................................~...........'.....................................................................*........... + // trn2 v10.2d, v25.2d, v27.2d // .........................................................................~........'........................................................................*........ + // trn2 v11.2d, v26.2d, v28.2d // .............................................................................~....'............................................................................*.... + // trn1 v8.2d, v25.2d, v27.2d // ........................................................................~.........'.......................................................................*......... + // trn1 v9.2d, v26.2d, v28.2d // ...........................................................................~......'..........................................................................*...... + // str q8, [x0], #(16*4) // ..........................................................................~.......'.........................................................................*....... 
+ // str q9, [x0, #(-16*3)] // ..............................................................................~...'.............................................................................*... + // str q10, [x0, #(-16*2)] // ............................................................................~.....'...........................................................................*..... + // str q11, [x0, #(-16*1)] // .................................................................................~'................................................................................* + + sub count, count, 1 + cbnz count, 1b + // Instructions: 47 + // Expected cycles: 52 + // Expected IPC: 0.90 + // + // Cycle bound: 52.0 + // IPC bound: 0.90 + // + // Wall time: 5.32s + // User time: 5.32s + // + // ---------------- cycle (expected) -----------------> + // 0 25 50 + // |------------------------|------------------------|- + mls v23.8H, v21.8H, v7.H[0] // *................................................... + add v14.8H, v27.8H, v8.8H // .*.................................................. + ldr q1, [x2, #-32] // ..*................................................. + add v17.8H, v14.8H, v23.8H // ....*............................................... + sub v23.8H, v14.8H, v23.8H // .....*.............................................. + trn2 v11.4S, v12.4S, v5.4S // ......*............................................. + trn1 v27.4S, v12.4S, v5.4S // .......*............................................ + trn2 v2.4S, v17.4S, v23.4S // ........*........................................... + ldr q26, [x2, #-80] // .........*.......................................... + trn2 v14.2D, v2.2D, v11.2D // ...........*........................................ + trn1 v15.4S, v17.4S, v23.4S // ............*....................................... + mul v5.8H, v14.8H, v30.8H // .............*...................................... 
+ sqrdmulh v23.8H, v14.8H, v26.8H // ..............*..................................... + trn2 v17.2D, v15.2D, v27.2D // ...............*.................................... + trn1 v14.2D, v2.2D, v11.2D // ................*................................... + mul v21.8H, v17.8H, v30.8H // .................*.................................. + mls v5.8H, v23.8H, v7.H[0] // ..................*................................. + sqrdmulh v17.8H, v17.8H, v26.8H // ...................*................................ + ldr q2, [x2, #-48] // ....................*............................... + sub v23.8H, v14.8H, v5.8H // ......................*............................. + add v14.8H, v14.8H, v5.8H // .......................*............................ + mls v21.8H, v17.8H, v7.H[0] // ........................*........................... + mul v1.8H, v23.8H, v1.8H // .........................*.......................... + sqrdmulh v17.8H, v23.8H, v13.8H // ..........................*......................... + mul v23.8H, v14.8H, v28.8H // ...........................*........................ + sqrdmulh v14.8H, v14.8H, v2.8H // ............................*....................... + trn1 v28.2D, v15.2D, v27.2D // .............................*...................... + mls v1.8H, v17.8H, v7.H[0] // ..............................*..................... + sub v11.8H, v28.8H, v21.8H // ...............................*.................... + mls v23.8H, v14.8H, v7.H[0] // ................................*................... + add v17.8H, v28.8H, v21.8H // .................................*.................. + sub v14.8H, v11.8H, v1.8H // ..................................*................. + add v1.8H, v11.8H, v1.8H // ...................................*................ + sub v28.8H, v17.8H, v23.8H // ....................................*............... + add v2.8H, v17.8H, v23.8H // .....................................*.............. 
+ trn1 v23.4S, v1.4S, v14.4S // ......................................*............. + trn2 v14.4S, v1.4S, v14.4S // .......................................*............ + trn2 v17.4S, v2.4S, v28.4S // ........................................*........... + trn1 v28.4S, v2.4S, v28.4S // .........................................*.......... + trn2 v1.2D, v17.2D, v14.2D // ...........................................*........ + trn1 v14.2D, v17.2D, v14.2D // ............................................*....... + str q1, [x0, #48] // .............................................*...... + trn2 v1.2D, v28.2D, v23.2D // ..............................................*..... + str q14, [x0, #16] // ...............................................*.... + trn1 v14.2D, v28.2D, v23.2D // ................................................*... + str q1, [x0, #32] // .................................................*.. + str q14, [x0], #(16*4) // ...................................................* + + // ---------------- cycle (expected) -----------------> + // 0 25 50 + // |------------------------|------------------------|- + // mls v23.8H, v21.8H, v7.H[0] // *................................................... + // add v6.8H, v27.8H, v8.8H // .*.................................................. + // trn1 v11.4S, v12.4S, v5.4S // .......*............................................ + // sub v0.8H, v6.8H, v23.8H // .....*.............................................. + // ldr q16, [x2, #-80] // .........*.......................................... + // add v26.8H, v6.8H, v23.8H // ....*............................................... + // trn1 v27.4S, v26.4S, v0.4S // ............*....................................... + // trn2 v25.4S, v12.4S, v5.4S // ......*............................................. + // trn2 v12.4S, v26.4S, v0.4S // ........*........................................... + // trn2 v19.2D, v27.2D, v11.2D // ...............*.................................... 
+ // trn2 v10.2D, v12.2D, v25.2D // ...........*........................................ + // sqrdmulh v0.8H, v19.8H, v16.8H // ...................*................................ + // sqrdmulh v18.8H, v10.8H, v16.8H // ..............*..................................... + // trn1 v16.2D, v27.2D, v11.2D // .............................*...................... + // trn1 v2.2D, v12.2D, v25.2D // ................*................................... + // mul v12.8H, v10.8H, v30.8H // .............*...................................... + // mul v10.8H, v19.8H, v30.8H // .................*.................................. + // ldr q14, [x2, #-48] // ....................*............................... + // mls v10.8H, v0.8H, v7.H[0] // ........................*........................... + // mls v12.8H, v18.8H, v7.H[0] // ..................*................................. + // add v9.8H, v16.8H, v10.8H // .................................*.................. + // sub v16.8H, v16.8H, v10.8H // ...............................*.................... + // sub v25.8H, v2.8H, v12.8H // ......................*............................. + // add v30.8H, v2.8H, v12.8H // .......................*............................ + // sqrdmulh v22.8H, v25.8H, v13.8H // ..........................*......................... + // sqrdmulh v13.8H, v30.8H, v14.8H // ............................*....................... + // ldr q14, [x2, #-32] // ..*................................................. + // mul v5.8H, v30.8H, v28.8H // ...........................*........................ + // mul v26.8H, v25.8H, v14.8H // .........................*.......................... + // mls v5.8H, v13.8H, v7.H[0] // ................................*................... + // mls v26.8H, v22.8H, v7.H[0] // ..............................*..................... + // add v13.8H, v9.8H, v5.8H // .....................................*.............. 
+ // sub v9.8H, v9.8H, v5.8H // ....................................*............... + // sub v5.8H, v16.8H, v26.8H // ..................................*................. + // add v25.8H, v16.8H, v26.8H // ...................................*................ + // trn1 v15.4S, v13.4S, v9.4S // .........................................*.......... + // trn2 v3.4S, v13.4S, v9.4S // ........................................*........... + // trn1 v13.4S, v25.4S, v5.4S // ......................................*............. + // trn2 v31.4S, v25.4S, v5.4S // .......................................*............ + // trn1 v2.2D, v15.2D, v13.2D // ................................................*... + // trn2 v9.2D, v15.2D, v13.2D // ..............................................*..... + // str q2, [x0], #(16*4) // ...................................................* + // trn1 v29.2D, v3.2D, v31.2D // ............................................*....... + // str q9, [x0, #-32] // .................................................*.. + // trn2 v9.2D, v3.2D, v31.2D // ...........................................*........ + // str q29, [x0, #-48] // ...............................................*.... + // str q9, [x0, #-16] // .............................................*...... + + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/opt_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/opt_impl.h new file mode 100644 index 0000000000..b226740261 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/opt_impl.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include "arith_native_aarch64.h" + +#include "poly.h" +#include "polyvec.h" + +/* Set of primitives that this backend replaces */ +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_REJ_UNIFORM + +#define NTT_BOUND_NATIVE (6 * MLKEM_Q) +static INLINE void ntt_native(poly *data) +{ + ntt_asm_opt(data->coeffs, aarch64_ntt_zetas_layer01234, + aarch64_ntt_zetas_layer56); +} + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +static INLINE void intt_native(poly *data) +{ + intt_asm_opt(data->coeffs, aarch64_invntt_zetas_layer01234, + aarch64_invntt_zetas_layer56); +} + +static INLINE void poly_reduce_native(poly *data) +{ + poly_reduce_asm_opt(data->coeffs); +} +static INLINE void poly_tomont_native(poly *data) +{ + poly_tomont_asm_opt(data->coeffs); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + poly_mulcache_compute_asm_opt(x->coeffs, y->coeffs, + aarch64_zetas_mulcache_native, + aarch64_zetas_mulcache_twisted_native); +} +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_asm_opt( + r->coeffs, a->vec[0].coeffs, b->vec[0].coeffs, b_cache->vec[0].coeffs); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + poly_tobytes_asm_opt(r, a->coeffs); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + if (len != MLKEM_N || buflen % 24 != 0) + { + return -1; + } + return (int)rej_uniform_asm_clean(r, buf, buflen, rej_uniform_table); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ 
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/optimize.sh b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/optimize.sh new file mode 100755 index 0000000000..9d43dfa80d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/optimize.sh @@ -0,0 +1,121 @@ +#!/usr/bin/env sh +# Copyright (c) 2024 The mlkem-native project authors +# SPDX-License-Identifier: Apache-2.0 + +set -e + +TARGET_NAME="Cortex-A55" +TARGET=Arm_Cortex_A55 + +echo "* polyvec_basemul_acc_montgomery_cached, K=2, ${TARGET_NAME}" + +cp polyvec_clean.S polyvec_opt.S + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k2_clean,polyvec_basemul_acc_montgomery_cached_asm_k2_opt \ + -l k2_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c sw_pipelining.allow_post \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* polyvec_basemul_acc_montgomery_cached, K=3, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k3_clean,polyvec_basemul_acc_montgomery_cached_asm_k3_opt \ + -l k3_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c sw_pipelining.allow_post \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* polyvec_basemul_acc_montgomery_cached, K=4, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k4_clean,polyvec_basemul_acc_montgomery_cached_asm_k4_opt \ + -l k4_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c 
sw_pipelining.allow_post \ + -c constraints.stalls_first_attempt=64 + +cp poly_clean.S poly_opt.S + +echo "* poly_reduce, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_reduce_asm_clean,poly_reduce_asm_opt \ + -l loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* poly_mulcache_compute, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_mulcache_compute_asm_clean,poly_mulcache_compute_asm_opt \ + -l mulcache_compute_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* poly_tomont, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_tomont_asm_clean,poly_tomont_asm_opt \ + -l poly_tomont_asm_loop \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo " * ntt, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + ntt_clean.S -o ntt_opt.S \ + -r ntt_asm_clean,ntt_asm_opt \ + -l layer123_start \ + -l layer4567_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo " * intt, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + intt_clean.S -o intt_opt.S \ + -r intt_asm_clean,intt_asm_opt \ + -l layer123_start \ + -l layer4567_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c 
sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_clean.S new file mode 100644 index 0000000000..f70a402215 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_clean.S @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27 +c_mont_constant: dup8h -1044 // 2^16 % 3329 +c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16) + +/* + * Some modular arithmetic macros + */ + +/* Barrett reduction */ +.macro barrett_reduce a + sqdmulh tmp.8h, \a\().8h, modulus_twisted.h[0] + srshr tmp.8h, tmp.8h, #11 + mls \a\().8h, tmp.8h, modulus.h[0] +.endm + +/* Montgomery multiplication, with precomputed Montgomery twist + * Expects modulus in modulus.h[0]. 
*/ +.macro mulmod dst, src, const, const_twisted + sqrdmulh tmp0.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, tmp0.8h, modulus.h[0] +.endm + +/* Turns signed-canonical to unsigned canonical representative + * through conditional addition of the modulus. + * + * Expected modulus in `modulus`. */ +.macro scalar_signed_to_unsigned a + sshr mask.8h, \a\().8h, #15 + and mask.16b, modulus.16b, mask.16b + add \a\().8h, \a\().8h, mask.8h +.endm + +/********************************** + * poly_reduce() * + **********************************/ + +.global MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean) + + ptr .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + + tmp .req v1 + mask .req v2 + modulus .req v3 + q_modulus .req q3 + modulus_twisted .req v4 + q_modulus_twisted .req q4 + +MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #8 +loop_start: + ldr q_data, [ptr], #64 + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-64] + + ldr q_data, [ptr, #-48] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-48] + + ldr q_data, [ptr, #-32] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-32] + + ldr q_data, [ptr, #-16] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-16] + + subs count, count, #1 + cbnz count, loop_start + + ret + + .unreq ptr + .unreq count + + .unreq data + .unreq q_data + + .unreq tmp + .unreq mask + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_mulcache_compute() * + ********************************************/ + +.global MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean) + + cache_ptr .req x0 + data_ptr .req x1 + zeta_ptr .req x2 + zeta_twisted_ptr .req x3 + count .req x4 + + data_odd .req v0 + zeta .req v1 + q_zeta .req 
q1 + zeta_twisted .req v2 + q_zeta_twisted .req q2 + + tmp0 .req v3 + q_tmp0 .req q3 + tmp1 .req v4 + q_tmp1 .req q4 + dst .req v5 + q_dst .req q5 + + modulus .req v6 + q_modulus .req q6 + modulus_twisted .req v7 + q_modulus_twisted .req q7 + +MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean): + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #16 +mulcache_compute_loop_start: + ldr q_tmp0, [data_ptr], #32 + ldr q_tmp1, [data_ptr, #-16] + ldr q_zeta, [zeta_ptr], #16 + ldr q_zeta_twisted, [zeta_twisted_ptr], #16 + + // The mulcache of a polynomial a + b*X in Fq[X^2-zeta] is b*zeta; + // Since tmp0 || tmp1 represents multiple such polynomials as + // (a0,b0,a1,b1,...), extract only the odd elements. + uzp2 data_odd.8h, tmp0.8h, tmp1.8h + mulmod dst, data_odd, zeta, zeta_twisted + + str q_dst, [cache_ptr], #16 + + subs count, count, #1 + cbnz count, mulcache_compute_loop_start + + ret + + .unreq cache_ptr + .unreq data_ptr + .unreq zeta_ptr + .unreq zeta_twisted_ptr + .unreq count + + .unreq data_odd + .unreq zeta + .unreq q_zeta + .unreq zeta_twisted + .unreq q_zeta_twisted + + .unreq tmp0 + .unreq q_tmp0 + .unreq tmp1 + .unreq q_tmp1 + .unreq dst + .unreq q_dst + + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_tobytes() * + ********************************************/ +.global MLKEM_ASM_NAMESPACE(poly_tobytes_asm_clean) + + data0 .req v0 + data1 .req v1 + out0 .req v2 + out1 .req v3 + out2 .req v4 + tmp .req v5 + + dst .req x0 + src .req x1 + count .req x2 + +MLKEM_ASM_NAMESPACE(poly_tobytes_asm_clean): + + mov count, #16 +poly_tobytes_asm_clean_asm_loop_start: + ld2 {data0.8h, data1.8h}, [src], #32 + + // r[3 * i + 0] = (t0 >> 0); + xtn out0.8b, data0.8h + + // r[3 * i + 1] = (t0 >> 8); + shrn out1.8b, data0.8h, #8 + xtn tmp.8b, data1.8h + // r[3 * i + 1] = (t0 >> 8) | (t1 << 4); + sli out1.8b, tmp.8b, #4 + + // r[3 * i + 2] = (t1 
>> 4); + shrn out2.8b, data1.8h, #4 + + st3 {out0.8b, out1.8b, out2.8b}, [dst], #24 + + subs count, count, #1 + cbnz count, poly_tobytes_asm_clean_asm_loop_start + ret + + .unreq data0 + .unreq data1 + .unreq out0 + .unreq out1 + .unreq out2 + .unreq tmp + .unreq dst + .unreq src + .unreq count + +/********************************** + * poly_tomont() * + **********************************/ +.global MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean) + + src .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + res .req v1 + q_res .req q1 + + factor .req v2 + q_factor .req q2 + factor_t .req v3 + q_factor_t .req q3 + modulus .req v4 + q_modulus .req q4 + modulus_twisted .req v5 + q_modulus_twisted .req q5 + + tmp0 .req v6 + +MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + ldr q_factor, c_mont_constant + ldr q_factor_t, c_barrett_twist + + mov count, #8 +poly_tomont_asm_loop: + + ldr q_data, [src], #64 + mulmod res, data, factor, factor_t + str q_res, [src, #-64] + + ldr q_data, [src, #-48] + mulmod res, data, factor, factor_t + str q_res, [src, #-48] + + ldr q_data, [src, #-32] + mulmod res, data, factor, factor_t + str q_res, [src, #-32] + + ldr q_data, [src, #-16] + mulmod res, data, factor, factor_t + str q_res, [src, #-16] + + sub count, count, #1 + cbnz count, poly_tomont_asm_loop + + ret + + .unreq src + .unreq count + + .unreq data + .unreq q_data + .unreq res + .unreq q_res + + .unreq factor + .unreq q_factor + .unreq factor_t + .unreq q_factor_t + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + + .unreq tmp0 + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_opt.S new file mode 100644 index 0000000000..e58ee77c46 --- /dev/null +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/poly_opt.S @@ -0,0 +1,690 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27 +c_mont_constant: dup8h -1044 // 2^16 % 3329 +c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16) + +/* + * Some modular arithmetic macros + */ + +/* Barrett reduction */ +.macro barrett_reduce a + sqdmulh tmp.8h, \a\().8h, modulus_twisted.h[0] + srshr tmp.8h, tmp.8h, #11 + mls \a\().8h, tmp.8h, modulus.h[0] +.endm + +/* Montgomery multiplication, with precomputed Montgomery twist + * Expects modulus in modulus.h[0]. */ +.macro mulmod dst, src, const, const_twisted + sqrdmulh tmp0.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, tmp0.8h, modulus.h[0] +.endm + +/* Turns signed-canonical to unsigned canonical representative + * through conditional addition of the modulus. + * + * Expected modulus in `modulus`. 
*/ +.macro scalar_signed_to_unsigned a + sshr mask.8h, \a\().8h, #15 + and mask.16b, modulus.16b, mask.16b + add \a\().8h, \a\().8h, mask.8h +.endm + +/********************************** + * poly_reduce() * + **********************************/ + +.global MLKEM_ASM_NAMESPACE(poly_reduce_asm_opt) + + ptr .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + + tmp .req v1 + mask .req v2 + modulus .req v3 + q_modulus .req q3 + modulus_twisted .req v4 + q_modulus_twisted .req q4 + +MLKEM_ASM_NAMESPACE(poly_reduce_asm_opt): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #8 + // Instructions: 15 + // Expected cycles: 22 + // Expected IPC: 0.68 + + // Cycle bound: 22.0 + // IPC bound: 0.68 + + // Wall time: 0.05s + // User time: 0.05s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q21, [x0, #32] // *............................. + ldr q23, [x0, #48] // ..*........................... + sqdmulh v7.8H, v21.8H, v4.H[0] // ....*......................... + sqdmulh v30.8H, v23.8H, v4.H[0] // ......*....................... + srshr v7.8H, v7.8H, #11 // ........*..................... + srshr v30.8H, v30.8H, #11 // ..........*................... + mls v21.8H, v7.8H, v3.H[0] // ...........*.................. + mls v23.8H, v30.8H, v3.H[0] // .............*................ + ldr q5, [x0, #16] // ..............*............... + sshr v7.8H, v21.8H, #15 // ................*............. + sshr v30.8H, v23.8H, #15 // .................*............ + and v7.16B, v3.16B, v7.16B // ..................*........... + add v21.8H, v21.8H, v7.8H // ...................*.......... + and v7.16B, v3.16B, v30.16B // ....................*......... + add v16.8H, v23.8H, v7.8H // .....................*........ + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q30, [x0, #32] // *.............................. 
+ // sqdmulh v22.8H, v30.8H, v4.H[0] // ....*.......................... + // ldr q2, [x0, #48] // ..*............................ + // srshr v19.8H, v22.8H, #11 // ........*...................... + // mls v30.8H, v19.8H, v3.H[0] // ...........*................... + // sqdmulh v25.8H, v2.8H, v4.H[0] // ......*........................ + // sshr v31.8H, v30.8H, #15 // ................*.............. + // srshr v25.8H, v25.8H, #11 // ..........*.................... + // and v18.16B, v3.16B, v31.16B // ..................*............ + // mls v2.8H, v25.8H, v3.H[0] // .............*................. + // add v21.8H, v30.8H, v18.8H // ...................*........... + // ldr q5, [x0, #16] // ..............*................ + // sshr v18.8H, v2.8H, #15 // .................*............. + // and v27.16B, v3.16B, v18.16B // ....................*.......... + // add v16.8H, v2.8H, v27.8H // .....................*......... + + sub count, count, #1 +1: + // Instructions: 32 + // Expected cycles: 36 + // Expected IPC: 0.89 + + // Cycle bound: 36.0 + // IPC bound: 0.89 + + // Wall time: 1.05s + // User time: 1.05s + + // -------- cycle (expected) ---------> + // 0 25 + // |------------------------|---------- + ldr q6, [x0], #64 // *................................... + ldr q30, [x0, #32] // ..e................................. + sqdmulh v31.8H, v6.8H, v4.H[0] // ....*............................... + sqdmulh v29.8H, v5.8H, v4.H[0] // .....*.............................. + sqdmulh v22.8H, v30.8H, v4.H[0] // ......e............................. + str q16, [x0, #-16] // .......*............................ + srshr v20.8H, v31.8H, #11 // ........*........................... + srshr v28.8H, v29.8H, #11 // .........*.......................... + str q21, [x0, #-32] // ..........*......................... + mls v6.8H, v20.8H, v3.H[0] // ...........*........................ + mls v5.8H, v28.8H, v3.H[0] // ............*....................... 
+ ldr q2, [x0, #48] // .............e...................... + sshr v31.8H, v6.8H, #15 // ...............*.................... + srshr v19.8H, v22.8H, #11 // ................e................... + and v22.16B, v3.16B, v31.16B // .................*.................. + add v0.8H, v6.8H, v22.8H // ..................*................. + mls v30.8H, v19.8H, v3.H[0] // ...................e................ + sshr v26.8H, v5.8H, #15 // ....................*............... + sqdmulh v25.8H, v2.8H, v4.H[0] // .....................e.............. + and v17.16B, v3.16B, v26.16B // ......................*............. + add v1.8H, v5.8H, v17.8H // .......................*............ + sshr v31.8H, v30.8H, #15 // ........................e........... + srshr v25.8H, v25.8H, #11 // .........................e.......... + str q1, [x0, #-48] // ..........................*......... + and v18.16B, v3.16B, v31.16B // ...........................e........ + mls v2.8H, v25.8H, v3.H[0] // ............................e....... + add v21.8H, v30.8H, v18.8H // .............................e...... + ldr q5, [x0, #16] // ..............................e..... + sshr v18.8H, v2.8H, #15 // ................................e... + str q0, [x0, #-64] // .................................*.. + and v27.16B, v3.16B, v18.16B // ..................................e. + add v16.8H, v2.8H, v27.8H // ...................................e + + // ------------------------ cycle (expected) -------------------------> + // 0 25 50 + // |------------------------|------------------------|----------------- + // ldr q0, [x0], #64 // ..................................*................................. + // sqdmulh v1.8h, v0.8h, v4.h[0] // ..~...............................'...*............................. + // srshr v1.8h, v1.8h, #11 // ......~...........................'.......*......................... + // mls v0.8h, v1.8h, v3.h[0] // .........~........................'..........*...................... 
+ // sshr v2.8h, v0.8h, #15 // .............~....................'..............*.................. + // and v2.16b, v3.16b, v2.16b // ...............~..................'................*................ + // add v0.8h, v0.8h, v2.8h // ................~.................'.................*............... + // str q0, [x0, #-64] // ...............................~..'................................* + // ldr q0, [x0, #-48] // ............................e.....'.............................~... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ...~..............................'....*............................ + // srshr v1.8h, v1.8h, #11 // .......~..........................'........*........................ + // mls v0.8h, v1.8h, v3.h[0] // ..........~.......................'...........*..................... + // sshr v2.8h, v0.8h, #15 // ..................~...............'...................*............. + // and v2.16b, v3.16b, v2.16b // ....................~.............'.....................*........... + // add v0.8h, v0.8h, v2.8h // .....................~............'......................*.......... + // str q0, [x0, #-48] // ........................~.........'.........................*....... + // ldr q0, [x0, #-32] // e.................................'.~............................... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ....e.............................'.....~........................... + // srshr v1.8h, v1.8h, #11 // ..............e...................'...............~................. + // mls v0.8h, v1.8h, v3.h[0] // .................e................'..................~.............. + // sshr v2.8h, v0.8h, #15 // ......................e...........'.......................~......... + // and v2.16b, v3.16b, v2.16b // .........................e........'..........................~...... + // add v0.8h, v0.8h, v2.8h // ...........................e......'............................~.... 
+ // str q0, [x0, #-32] // ........~.........................'.........*....................... + // ldr q0, [x0, #-16] // ...........e......................'............~.................... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ...................e..............'....................~............ + // srshr v1.8h, v1.8h, #11 // .......................e..........'........................~........ + // mls v0.8h, v1.8h, v3.h[0] // ..........................e.......'...........................~..... + // sshr v2.8h, v0.8h, #15 // ..............................e...'...............................~. + // and v2.16b, v3.16b, v2.16b // ................................e.'................................. + // add v0.8h, v0.8h, v2.8h // .................................e'................................. + // str q0, [x0, #-16] // .....~............................'......*.......................... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 17 + // Expected cycles: 23 + // Expected IPC: 0.74 + + // Cycle bound: 23.0 + // IPC bound: 0.74 + + // Wall time: 0.05s + // User time: 0.05s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + sqdmulh v20.8H, v5.8H, v4.H[0] // *............................. + ldr q24, [x0], #64 // .*............................ + str q21, [x0, #-32] // ...*.......................... + srshr v20.8H, v20.8H, #11 // ....*......................... + sqdmulh v25.8H, v24.8H, v4.H[0] // .....*........................ + str q16, [x0, #-16] // ......*....................... + mls v5.8H, v20.8H, v3.H[0] // .......*...................... + srshr v20.8H, v25.8H, #11 // .........*.................... + sshr v2.8H, v5.8H, #15 // ...........*.................. + mls v24.8H, v20.8H, v3.H[0] // ............*................. + and v20.16B, v3.16B, v2.16B // .............*................ + add v31.8H, v5.8H, v20.8H // ..............*............... + sshr v20.8H, v24.8H, #15 // ................*............. 
+ str q31, [x0, #-48] // .................*............ + and v31.16B, v3.16B, v20.16B // ..................*........... + add v24.8H, v24.8H, v31.8H // ...................*.......... + str q24, [x0, #-64] // ......................*....... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q6, [x0], #64 // .*............................. + // sqdmulh v31.8H, v6.8H, v4.H[0] // .....*......................... + // sqdmulh v29.8H, v5.8H, v4.H[0] // *.............................. + // str q16, [x0, #-16] // ......*........................ + // srshr v20.8H, v31.8H, #11 // .........*..................... + // srshr v28.8H, v29.8H, #11 // ....*.......................... + // str q21, [x0, #-32] // ...*........................... + // mls v6.8H, v20.8H, v3.H[0] // ............*.................. + // mls v5.8H, v28.8H, v3.H[0] // .......*....................... + // sshr v31.8H, v6.8H, #15 // ................*.............. + // and v22.16B, v3.16B, v31.16B // ..................*............ + // add v0.8H, v6.8H, v22.8H // ...................*........... + // sshr v26.8H, v5.8H, #15 // ...........*................... + // and v17.16B, v3.16B, v26.16B // .............*................. + // add v1.8H, v5.8H, v17.8H // ..............*................ + // str q1, [x0, #-48] // .................*............. + // str q0, [x0, #-64] // ......................*........ 
+ + + ret + + .unreq ptr + .unreq count + + .unreq data + .unreq q_data + + .unreq tmp + .unreq mask + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_mulcache_compute() * + ********************************************/ + +.global MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_opt) + + cache_ptr .req x0 + data_ptr .req x1 + zeta_ptr .req x2 + zeta_twisted_ptr .req x3 + count .req x4 + + data_odd .req v0 + zeta .req v1 + q_zeta .req q1 + zeta_twisted .req v2 + q_zeta_twisted .req q2 + + tmp0 .req v3 + q_tmp0 .req q3 + tmp1 .req v4 + q_tmp1 .req q4 + dst .req v5 + q_dst .req q5 + + modulus .req v6 + q_modulus .req q6 + modulus_twisted .req v7 + q_modulus_twisted .req q7 + +MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_opt): + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #16 + // Instructions: 7 + // Expected cycles: 12 + // Expected IPC: 0.58 + + // Cycle bound: 12.0 + // IPC bound: 0.58 + + // Wall time: 0.01s + // User time: 0.01s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q1, [x1, #16] // *............................. + ldr q27, [x1], #32 // ..*........................... + ldr q23, [x2], #16 // ....*......................... + uzp2 v27.8H, v27.8H, v1.8H // ......*....................... + ldr q1, [x3], #16 // .......*...................... + mul v2.8H, v27.8H, v23.8H // .........*.................... + sqrdmulh v27.8H, v27.8H, v1.8H // ...........*.................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q29, [x1, #16] // *.............................. + // ldr q21, [x2], #16 // ....*.......................... + // ldr q27, [x1], #32 // ..*............................ + // ldr q7, [x3], #16 // .......*....................... + // uzp2 v28.8H, v27.8H, v29.8H // ......*........................ 
+ // mul v2.8H, v28.8H, v21.8H // .........*..................... + // sqrdmulh v27.8H, v28.8H, v7.8H // ...........*................... + + sub count, count, #1 +1: + // Instructions: 9 + // Expected cycles: 13 + // Expected IPC: 0.69 + + // Cycle bound: 13.0 + // IPC bound: 0.69 + + // Wall time: 0.09s + // User time: 0.09s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q29, [x1, #16] // e............................. + ldr q21, [x2], #16 // ..e........................... + mls v2.8H, v27.8H, v6.H[0] // ....*......................... + ldr q27, [x1], #32 // .....e........................ + ldr q7, [x3], #16 // .......e...................... + uzp2 v28.8H, v27.8H, v29.8H // .........e.................... + str q2, [x0], #16 // ..........*................... + mul v2.8H, v28.8H, v21.8H // ...........e.................. + sqrdmulh v27.8H, v28.8H, v7.8H // ............e................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q3, [x1], #32 // .....e.......'....~.......'.... + // ldr q4, [x1, #-16] // e............~............~.... + // ldr q1, [x2], #16 // ..e..........'.~..........'.~.. + // ldr q2, [x3], #16 // .......e.....'......~.....'.... + // uzp2 v0.8h, v3.8h, v4.8h // .........e...'........~...'.... + // sqrdmulh v3.8h, v0.8h, v2.8h // ............e'...........~'.... + // mul v5.8h, v0.8h, v1.8h // ...........e.'..........~.'.... + // mls v5.8h, v3.8h, v6.h[0] // ....~........'...*........'.... + // str q5, [x0], #16 // ..........~..'.........*..'.... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 2 + // Expected cycles: 5 + // Expected IPC: 0.40 + + // Cycle bound: 5.0 + // IPC bound: 0.40 + + // Wall time: 0.00s + // User time: 0.00s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v2.8H, v27.8H, v6.H[0] // *............................. + str q2, [x0], #16 // ....*......................... 
+ + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // mls v2.8H, v27.8H, v6.H[0] // *.............................. + // str q2, [x0], #16 // ....*.......................... + + + ret + + .unreq cache_ptr + .unreq data_ptr + .unreq zeta_ptr + .unreq zeta_twisted_ptr + .unreq count + + .unreq data_odd + .unreq zeta + .unreq q_zeta + .unreq zeta_twisted + .unreq q_zeta_twisted + + .unreq tmp0 + .unreq q_tmp0 + .unreq tmp1 + .unreq q_tmp1 + .unreq dst + .unreq q_dst + + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_tobytes() * + ********************************************/ +.global MLKEM_ASM_NAMESPACE(poly_tobytes_asm_opt) + + data0 .req v0 + data1 .req v1 + out0 .req v2 + out1 .req v3 + out2 .req v4 + tmp .req v5 + + dst .req x0 + src .req x1 + count .req x2 + +MLKEM_ASM_NAMESPACE(poly_tobytes_asm_opt): + + mov count, #16 +poly_tobytes_asm_opt_asm_loop_start: + ld2 {data0.8h, data1.8h}, [src], #32 + + // r[3 * i + 0] = (t0 >> 0); + xtn out0.8b, data0.8h + + // r[3 * i + 1] = (t0 >> 8); + shrn out1.8b, data0.8h, #8 + xtn tmp.8b, data1.8h + // r[3 * i + 1] = (t0 >> 8) | (t1 << 4); + sli out1.8b, tmp.8b, #4 + + // r[3 * i + 2] = (t1 >> 4); + shrn out2.8b, data1.8h, #4 + + st3 {out0.8b, out1.8b, out2.8b}, [dst], #24 + + subs count, count, #1 + cbnz count, poly_tobytes_asm_opt_asm_loop_start + ret + + .unreq data0 + .unreq data1 + .unreq out0 + .unreq out1 + .unreq out2 + .unreq tmp + .unreq dst + .unreq src + .unreq count + +/********************************** + * poly_tomont() * + **********************************/ +.global MLKEM_ASM_NAMESPACE(poly_tomont_asm_opt) + + src .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + res .req v1 + q_res .req q1 + + factor .req v2 + q_factor .req q2 + factor_t .req v3 + q_factor_t .req q3 + modulus .req v4 + q_modulus .req q4 + modulus_twisted .req v5 + q_modulus_twisted .req q5 + + tmp0 
.req v6 + +MLKEM_ASM_NAMESPACE(poly_tomont_asm_opt): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + ldr q_factor, c_mont_constant + ldr q_factor_t, c_barrett_twist + + mov count, #8 + // Instructions: 5 + // Expected cycles: 7 + // Expected IPC: 0.71 + // + // Cycle bound: 7.0 + // IPC bound: 0.71 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q26, [x0, #48] // *............................. + ldr q23, [x0, #16] // ..*........................... + mul v17.8H, v26.8H, v2.8H // ....*......................... + sqrdmulh v7.8H, v26.8H, v3.8H // .....*........................ + ldr q27, [x0, #32] // ......*....................... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q7, [x0, #48] // *.............................. + // ldr q23, [x0, #16] // ..*............................ + // mul v17.8H, v7.8H, v2.8H // ....*.......................... + // sqrdmulh v7.8H, v7.8H, v3.8H // .....*......................... + // ldr q27, [x0, #32] // ......*........................ + + sub count, count, #1 +1: + // Instructions: 20 + // Expected cycles: 24 + // Expected IPC: 0.83 + // + // Cycle bound: 24.0 + // IPC bound: 0.83 + // + // Wall time: 0.73s + // User time: 0.73s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v17.8H, v7.8H, v4.H[0] // *............................. + sqrdmulh v5.8H, v23.8H, v3.8H // .*............................ + ldr q7, [x0], #64 // ..*........................... + str q17, [x0, #-16] // ....*......................... + sqrdmulh v29.8H, v27.8H, v3.8H // .....*........................ + sqrdmulh v19.8H, v7.8H, v3.8H // ......*....................... + mul v25.8H, v23.8H, v2.8H // .......*...................... + mul v0.8H, v7.8H, v2.8H // ........*..................... + mul v26.8H, v27.8H, v2.8H // .........*.................... 
+ ldr q7, [x0, #48] // ..........e................... + mls v25.8H, v5.8H, v4.H[0] // ............*................. + ldr q23, [x0, #16] // .............e................ + mls v26.8H, v29.8H, v4.H[0] // ...............*.............. + mls v0.8H, v19.8H, v4.H[0] // ................*............. + str q25, [x0, #-48] // .................*............ + mul v17.8H, v7.8H, v2.8H // ..................e........... + sqrdmulh v7.8H, v7.8H, v3.8H // ...................e.......... + str q0, [x0, #-64] // ....................*......... + ldr q27, [x0, #32] // .....................e........ + str q26, [x0, #-32] // .......................*...... + + // --------- cycle (expected) ----------> + // 0 25 + // |------------------------|------------ + // ldr q0, [x0], #64 // ..............'.*..................... + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'.....*................. + // mul v1.8h, v0.8h, v2.8h // ..............'.......*............... + // mls v1.8h, v6.8h, v4.h[0] // ......~.......'...............*....... + // str q1, [x0, #-64] // ..........~...'...................*... + // ldr q0, [x0, #-48] // ...e..........'............~.......... + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'*...................... + // mul v1.8h, v0.8h, v2.8h // ..............'......*................ + // mls v1.8h, v6.8h, v4.h[0] // ..~...........'...........*........... + // str q1, [x0, #-48] // .......~......'................*...... + // ldr q0, [x0, #-32] // ...........e..'....................~.. + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'....*.................. + // mul v1.8h, v0.8h, v2.8h // ..............'........*.............. + // mls v1.8h, v6.8h, v4.h[0] // .....~........'..............*........ + // str q1, [x0, #-32] // .............~'......................* + // ldr q0, [x0, #-16] // e.............'.........~............. + // sqrdmulh v6.8h, v0.8h, v3.8h // .........e....'..................~.... 
+ // mul v1.8h, v0.8h, v2.8h // ........e.....'.................~..... + // mls v1.8h, v6.8h, v4.h[0] // ..............*....................... + // str q1, [x0, #-16] // ..............'...*................... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 15 + // Expected cycles: 18 + // Expected IPC: 0.83 + // + // Cycle bound: 18.0 + // IPC bound: 0.83 + // + // Wall time: 0.07s + // User time: 0.07s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v17.8H, v7.8H, v4.H[0] // *............................. + sqrdmulh v7.8H, v23.8H, v3.8H // .*............................ + mul v26.8H, v23.8H, v2.8H // ..*........................... + sqrdmulh v25.8H, v27.8H, v3.8H // ...*.......................... + ldr q23, [x0], #64 // ....*......................... + mul v27.8H, v27.8H, v2.8H // ......*....................... + mls v26.8H, v7.8H, v4.H[0] // .......*...................... + sqrdmulh v7.8H, v23.8H, v3.8H // ........*..................... + mul v23.8H, v23.8H, v2.8H // .........*.................... + str q17, [x0, #-16] // ..........*................... + mls v27.8H, v25.8H, v4.H[0] // ...........*.................. + str q26, [x0, #-48] // ............*................. + mls v23.8H, v7.8H, v4.H[0] // .............*................ + str q27, [x0, #-32] // ...............*.............. + str q23, [x0, #-64] // .................*............ + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // mls v17.8H, v7.8H, v4.H[0] // *.............................. + // sqrdmulh v5.8H, v23.8H, v3.8H // .*............................. + // ldr q7, [x0], #64 // ....*.......................... + // str q17, [x0, #-16] // ..........*.................... + // sqrdmulh v29.8H, v27.8H, v3.8H // ...*........................... + // sqrdmulh v19.8H, v7.8H, v3.8H // ........*...................... + // mul v25.8H, v23.8H, v2.8H // ..*............................ 
+ // mul v0.8H, v7.8H, v2.8H // .........*..................... + // mul v26.8H, v27.8H, v2.8H // ......*........................ + // mls v25.8H, v5.8H, v4.H[0] // .......*....................... + // mls v26.8H, v29.8H, v4.H[0] // ...........*................... + // mls v0.8H, v19.8H, v4.H[0] // .............*................. + // str q25, [x0, #-48] // ............*.................. + // str q0, [x0, #-64] // .................*............. + // str q26, [x0, #-32] // ...............*............... + + + ret + + .unreq src + .unreq count + + .unreq data + .unreq q_data + .unreq res + .unreq q_res + + .unreq factor + .unreq q_factor + .unreq factor_t + .unreq q_factor_t + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + + .unreq tmp0 + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_clean.S new file mode 100644 index 0000000000..99fb05de5d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_clean.S @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +// +// AArch64 re-implementation of the asymmetric base multiplication from: +// +// Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1 +// https://eprint.iacr.org/2021/986 +// https://github.com/neon-ntt/neon-ntt + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64.
+ * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 3327 // -3329^-1 mod 2^16 (3329 * 3327 = 169 * 2^16 - 1) + +// Input: +// - Vectors al, ah of 32-bit entries +// Output: +// - Montgomery reductions of al || ah, stored in x (al, ah are clobbered) +.macro montgomery_reduce_long x, a + uzp1 t0.8h, \a\()l.8h, \a\()h.8h + mul t0.8h, t0.8h, modulus_twisted.8h + smlal \a\()l.4s, t0.4h, modulus.4h + smlal2 \a\()h.4s, t0.8h, modulus.8h + uzp2 \x\().8h, \a\()l.8h, \a\()h.8h +.endm + +// Computes products (a0*b0 + a1*b1t, a0*b1 + a1*b0) in 32-bit. +// +// Bounds: +// - Assume |a| < 4096, +// - Result: < 2*4096*2^15 = 2^28 +.macro pmull d, a, b + smull \d\()0l.4s, \a\()0.4h, \b\()0.4h + smull2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smull \d\()1l.4s, \a\()0.4h, \b\()1.4h + smull2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro pmlal d, a, b + smlal \d\()0l.4s, \a\()0.4h, \b\()0.4h + smlal2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smlal \d\()1l.4s, \a\()0.4h, \b\()1.4h + smlal2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro ld2_wrap a, ptr + ldr q_tmp0, [\ptr\()], #32 + ldr q_tmp1, [\ptr\(), #-16] + uzp1 \a\()0.8h, tmp0.8h, tmp1.8h + uzp2 \a\()1.8h, tmp0.8h, tmp1.8h +.endm + +.macro st2_wrap a, ptr + zip1 tmp0.8h, \a\()0.8h, \a\()1.8h + zip2 tmp1.8h, \a\()0.8h, \a\()1.8h + str q_tmp0, [\ptr\()], #32 + str q_tmp1, [\ptr\(), #-16] +.endm + +.macro
load_polys a, b, a_ptr, b_ptr, b_cache_ptr + ld2_wrap \a\(), \a_ptr + ld2_wrap \b\(), \b_ptr + ld1 {\b\()1t.8h}, [\b_cache_ptr], #16 +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + out .req x0 + a0_ptr .req x1 + b0_ptr .req x2 + b0_cache_ptr .req x3 + a1_ptr .req x4 + b1_ptr .req x5 + b1_cache_ptr .req x6 + a2_ptr .req x7 + b2_ptr .req x8 + b2_cache_ptr .req x9 + a3_ptr .req x10 + b3_ptr .req x11 + b3_cache_ptr .req x12 + count .req x13 + + modulus .req v0 + q_modulus .req q0 + modulus_twisted .req v2 + q_modulus_twisted .req q2 + + aa0 .req v3 + aa1 .req v4 + bb0 .req v5 + bb1 .req v6 + bb1t .req v7 + + res0l .req v8 + res1l .req v9 + res0h .req v10 + res1h .req v11 + + tmp0 .req v12 + tmp1 .req v13 + q_tmp0 .req q12 + q_tmp1 .req q13 + + out0 .req v26 + out1 .req v27 + + t0 .req v28 + +#if MLKEM_K == 2 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + + mov count, #(MLKEM_N / 16) +k2_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k2_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 2 */ + +#if MLKEM_K == 3 +.global 
MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + + mov count, #(MLKEM_N / 16) +k3_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr, b2_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k3_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 3 */ + +#if MLKEM_K == 4 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + add a3_ptr, a0_ptr, #(3 * 512) + add b3_ptr, b0_ptr, #(3 * 512) + add b3_cache_ptr, b0_cache_ptr, #(3 * 512/2) + + // Bounds: + // + // Each pmull is bound by 2*4096*2^15=2^28, so the final value + // before Montgomery reduction is bound by 2^30. 
+ + mov count, #(MLKEM_N / 16) +k4_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr, b2_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a3_ptr, b3_ptr, b3_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k4_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 4 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_opt.S new file mode 100644 index 0000000000..16ed77c3fc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/polyvec_opt.S @@ -0,0 +1,1584 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// AArch64 re-implementation of the asymmetric base multiplication from: + +// Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1 +// https://eprint.iacr.org/2021/986 +// https://github.com/neon-ntt/neon-ntt + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK.
+ */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 3327 // -3329^-1 mod 2^16 (3329 * 3327 = 169 * 2^16 - 1) + +// Input: +// - Vectors al, ah of 32-bit entries +// Output: +// - Montgomery reductions of al || ah, stored in x (al, ah are clobbered) +.macro montgomery_reduce_long x, a + uzp1 t0.8h, \a\()l.8h, \a\()h.8h + mul t0.8h, t0.8h, modulus_twisted.8h + smlal \a\()l.4s, t0.4h, modulus.4h + smlal2 \a\()h.4s, t0.8h, modulus.8h + uzp2 \x\().8h, \a\()l.8h, \a\()h.8h +.endm + +// Computes products (a0*b0 + a1*b1t, a0*b1 + a1*b0) in 32-bit. + +// Bounds: +// - Assume |a| < 4096, +// - Result: < 2*4096*2^15 = 2^28 +.macro pmull d, a, b + smull \d\()0l.4s, \a\()0.4h, \b\()0.4h + smull2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smull \d\()1l.4s, \a\()0.4h, \b\()1.4h + smull2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro pmlal d, a, b + smlal \d\()0l.4s, \a\()0.4h, \b\()0.4h + smlal2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smlal \d\()1l.4s, \a\()0.4h, \b\()1.4h + smlal2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro ld2_wrap a, ptr + ldr q_tmp0, [\ptr\()], #32 + ldr q_tmp1, [\ptr\(), #-16] + uzp1 \a\()0.8h, tmp0.8h, tmp1.8h + uzp2 \a\()1.8h, tmp0.8h, tmp1.8h +.endm + +.macro st2_wrap a, ptr + zip1 tmp0.8h, \a\()0.8h, \a\()1.8h + zip2 tmp1.8h, \a\()0.8h, \a\()1.8h + str q_tmp0, [\ptr\()], #32 + str q_tmp1, [\ptr\(), #-16] +.endm + +.macro load_polys a, b, a_ptr, b_ptr, b_cache_ptr + ld2_wrap \a\(), \a_ptr + ld2_wrap \b\(), \b_ptr + ld1 {\b\()1t.8h}, [\b_cache_ptr], #16 +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11,
[sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + out .req x0 + a0_ptr .req x1 + b0_ptr .req x2 + b0_cache_ptr .req x3 + a1_ptr .req x4 + b1_ptr .req x5 + b1_cache_ptr .req x6 + a2_ptr .req x7 + b2_ptr .req x8 + b2_cache_ptr .req x9 + a3_ptr .req x10 + b3_ptr .req x11 + b3_cache_ptr .req x12 + count .req x13 + + modulus .req v0 + q_modulus .req q0 + modulus_twisted .req v2 + q_modulus_twisted .req q2 + + aa0 .req v3 + aa1 .req v4 + bb0 .req v5 + bb1 .req v6 + bb1t .req v7 + + res0l .req v8 + res1l .req v9 + res0h .req v10 + res1h .req v11 + + tmp0 .req v12 + tmp1 .req v13 + q_tmp0 .req q12 + q_tmp1 .req q13 + + out0 .req v26 + out1 .req v27 + + t0 .req v28 + +#if MLKEM_K == 2 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + + mov count, #(MLKEM_N / 16) + // Instructions: 75 + // Expected cycles: 94 + // Expected IPC: 0.80 + + // Cycle bound: 94.0 + // IPC bound: 0.80 + + // Wall time: 1.49s + // User time: 1.49s + + // --------------------------- original position ----------------------------> + // 0 25 50 + // |------------------------|------------------------| + ldr q9, [x4], #32 // *.......................................................................... + ldr q5, [x4, #-16] // ......*.................................................................... + ldr q11, [x5], #32 // .*......................................................................... 
+ uzp1 v23.8H, v9.8H, v5.8H // .........*................................................................. + uzp2 v9.8H, v9.8H, v5.8H // .....................*..................................................... + ldr q5, [x2], #32 // ..*........................................................................ + ldr q7, [x5, #-16] // ..............*............................................................ + ldr q21, [x2, #-16] // ...*....................................................................... + uzp2 v10.8H, v11.8H, v7.8H // .................*......................................................... + uzp1 v11.8H, v11.8H, v7.8H // ..................*........................................................ + uzp1 v7.8H, v5.8H, v21.8H // ....*...................................................................... + uzp2 v5.8H, v5.8H, v21.8H // .....*..................................................................... + ldr q21, [x1], #32 // .......*................................................................... + ldr q25, [x1, #-16] // ........*.................................................................. + ld1 {v6.8H}, [x3], #16 // ............................*.............................................. + uzp1 v26.8H, v21.8H, v25.8H // ..........*................................................................ + uzp2 v21.8H, v21.8H, v25.8H // ...........*............................................................... + smull v25.4S, v26.4H, v5.4H // ............*.............................................................. + smull2 v5.4S, v26.8H, v5.8H // .............*............................................................. + smull v19.4S, v26.4H, v7.4H // ..........................*................................................ + smull2 v26.4S, v26.8H, v7.8H // ..............................*............................................ 
+ smlal v25.4S, v21.4H, v7.4H // ...............*........................................................... + smlal2 v5.4S, v21.8H, v7.8H // ................*.......................................................... + smlal v19.4S, v21.4H, v6.4H // ...................................*....................................... + smlal2 v26.4S, v21.8H, v6.8H // .................................*......................................... + smlal v25.4S, v23.4H, v10.4H // ...................*....................................................... + smlal2 v5.4S, v23.8H, v10.8H // ....................*...................................................... + smlal v19.4S, v23.4H, v11.4H // ......................................*.................................... + smlal2 v26.4S, v23.8H, v11.8H // ....................................*...................................... + ld1 {v23.8H}, [x6], #16 // ........................*.................................................. + smlal v25.4S, v9.4H, v11.4H // ......................*.................................................... + smlal2 v5.4S, v9.8H, v11.8H // .......................*................................................... + smlal2 v26.4S, v9.8H, v23.8H // .......................................*................................... + smlal v19.4S, v9.4H, v23.4H // .........................................*................................. + ldr q9, [x4], #32 // ...............................*........................................... + uzp1 v11.8H, v25.8H, v5.8H // .........................*................................................. + uzp1 v23.8H, v19.8H, v26.8H // .............................................*............................. + mul v11.8H, v11.8H, v2.8H // ...........................*............................................... + mul v23.8H, v23.8H, v2.8H // ..............................................*............................ 
+ ldr q7, [x5], #32 // ................................*.......................................... + smlal2 v5.4S, v11.8H, v0.8H // .............................*............................................. + smlal v25.4S, v11.4H, v0.4H // ..................................*........................................ + ldr q11, [x2], #32 // .....................................*..................................... + ldr q21, [x2, #-16] // ........................................*.................................. + ldr q6, [x4, #-16] // ...............................................*........................... + uzp1 v17.8H, v11.8H, v21.8H // ...........................................*............................... + ldr q10, [x1], #32 // ................................................*.......................... + ldr q29, [x1, #-16] // .................................................*......................... + uzp2 v11.8H, v11.8H, v21.8H // ............................................*.............................. + uzp1 v13.8H, v9.8H, v6.8H // ...................................................*....................... + uzp1 v3.8H, v10.8H, v29.8H // ....................................................*...................... + uzp2 v10.8H, v10.8H, v29.8H // .....................................................*..................... + smull v12.4S, v3.4H, v11.4H // ......................................................*.................... + smull2 v11.4S, v3.8H, v11.8H // .......................................................*................... + ldr q21, [x5, #-16] // ........................................................*.................. + smlal v12.4S, v10.4H, v17.4H // .........................................................*................. + smlal2 v11.4S, v10.8H, v17.8H // ..........................................................*................ 
+ uzp2 v29.8H, v7.8H, v21.8H // ...........................................................*............... + uzp1 v15.8H, v7.8H, v21.8H // ............................................................*.............. + smlal v12.4S, v13.4H, v29.4H // .............................................................*............. + smlal2 v11.4S, v13.8H, v29.8H // ..............................................................*............ + uzp2 v28.8H, v9.8H, v6.8H // ...............................................................*........... + smlal2 v26.4S, v23.8H, v0.8H // ..................................................*........................ + smlal v12.4S, v28.4H, v15.4H // .................................................................*......... + smlal2 v11.4S, v28.8H, v15.8H // ..................................................................*........ + smlal v19.4S, v23.4H, v0.4H // ................................................................*.......... + uzp2 v27.8H, v25.8H, v5.8H // ..........................................*................................ + smull v23.4S, v3.4H, v17.4H // ......................................................................*.... + uzp1 v9.8H, v12.8H, v11.8H // .....................................................................*..... + uzp2 v19.8H, v19.8H, v26.8H // ....................................................................*...... + mul v14.8H, v9.8H, v2.8H // .......................................................................*... + ld1 {v22.8H}, [x6], #16 // ...................................................................*....... + zip2 v9.8H, v19.8H, v27.8H // ........................................................................*.. + smlal2 v11.4S, v14.8H, v0.8H // ..........................................................................* + ld1 {v4.8H}, [x3], #16 // .........................................................................*. 
+ + // ------------------------------ new position ------------------------------> + // 0 25 50 + // |------------------------|------------------------|------------------------ + // ldr q18, [x4], #32 // *.......................................................................... + // ldr q30, [x5], #32 // ..*........................................................................ + // ldr q8, [x2], #32 // .....*..................................................................... + // ldr q9, [x2, #-16] // .......*................................................................... + // uzp1 v17.8H, v8.8H, v9.8H // ..........*................................................................ + // uzp2 v4.8H, v8.8H, v9.8H // ...........*............................................................... + // ldr q19, [x4, #-16] // .*......................................................................... + // ldr q29, [x1], #32 // ............*.............................................................. + // ldr q12, [x1, #-16] // .............*............................................................. + // uzp1 v13.8H, v18.8H, v19.8H // ...*....................................................................... + // uzp1 v3.8H, v29.8H, v12.8H // ...............*........................................................... + // uzp2 v10.8H, v29.8H, v12.8H // ................*.......................................................... + // smull v12.4S, v3.4H, v4.4H // .................*......................................................... + // smull2 v11.4S, v3.8H, v4.8H // ..................*........................................................ + // ldr q5, [x5, #-16] // ......*.................................................................... + // smlal v12.4S, v10.4H, v17.4H // .....................*..................................................... 
+ // smlal2 v11.4S, v10.8H, v17.8H // ......................*.................................................... + // uzp2 v14.8H, v30.8H, v5.8H // ........*.................................................................. + // uzp1 v15.8H, v30.8H, v5.8H // .........*................................................................. + // smlal v12.4S, v13.4H, v14.4H // .........................*................................................. + // smlal2 v11.4S, v13.8H, v14.8H // ..........................*................................................ + // uzp2 v28.8H, v18.8H, v19.8H // ....*...................................................................... + // smlal v12.4S, v28.4H, v15.4H // ..............................*............................................ + // smlal2 v11.4S, v28.8H, v15.8H // ...............................*........................................... + // ld1 {v22.8H}, [x6], #16 // .............................*............................................. + // uzp1 v1.8H, v12.8H, v11.8H // ...................................*....................................... + // smull v23.4S, v3.4H, v17.4H // ...................*....................................................... + // mul v14.8H, v1.8H, v2.8H // .....................................*..................................... + // ld1 {v4.8H}, [x3], #16 // ..............*............................................................ + // smlal2 v11.4S, v14.8H, v0.8H // ........................................*.................................. + // smull2 v20.4S, v3.8H, v17.8H // ....................*...................................................... + // ldr q18, [x4], #32 // ..................................*........................................ + // ldr q30, [x5], #32 // .......................................*................................... + // smlal2 v20.4S, v10.8H, v4.8H // ........................*.................................................. 
+ // smlal v12.4S, v14.4H, v0.4H // .........................................*................................. + // smlal v23.4S, v10.4H, v4.4H // .......................*................................................... + // smlal2 v20.4S, v13.8H, v15.8H // ............................*.............................................. + // ldr q8, [x2], #32 // ..........................................*................................ + // smlal v23.4S, v13.4H, v15.4H // ...........................*............................................... + // smlal2 v20.4S, v28.8H, v22.8H // ................................*.......................................... + // ldr q9, [x2, #-16] // ...........................................*............................... + // smlal v23.4S, v28.4H, v22.4H // .................................*......................................... + // uzp2 v27.8H, v12.8H, v11.8H // ..................................................................*........ + // uzp1 v17.8H, v8.8H, v9.8H // .............................................*............................. + // uzp2 v4.8H, v8.8H, v9.8H // ................................................*.......................... + // uzp1 v5.8H, v23.8H, v20.8H // ....................................*...................................... + // mul v31.8H, v5.8H, v2.8H // ......................................*.................................... + // ldr q19, [x4, #-16] // ............................................*.............................. + // ldr q29, [x1], #32 // ..............................................*............................ + // ldr q12, [x1, #-16] // ...............................................*........................... + // smlal2 v20.4S, v31.8H, v0.8H // ..............................................................*............ + // uzp1 v13.8H, v18.8H, v19.8H // .................................................*......................... 
+ // uzp1 v3.8H, v29.8H, v12.8H // ..................................................*........................ + // uzp2 v10.8H, v29.8H, v12.8H // ...................................................*....................... + // smull v12.4S, v3.4H, v4.4H // ....................................................*...................... + // smull2 v11.4S, v3.8H, v4.8H // .....................................................*..................... + // ldr q5, [x5, #-16] // ......................................................*.................... + // smlal v12.4S, v10.4H, v17.4H // .......................................................*................... + // smlal2 v11.4S, v10.8H, v17.8H // ........................................................*.................. + // uzp2 v14.8H, v30.8H, v5.8H // .........................................................*................. + // uzp1 v15.8H, v30.8H, v5.8H // ..........................................................*................ + // smlal v12.4S, v13.4H, v14.4H // ...........................................................*............... + // smlal2 v11.4S, v13.8H, v14.8H // ............................................................*.............. + // uzp2 v28.8H, v18.8H, v19.8H // .............................................................*............. + // smlal v23.4S, v31.4H, v0.4H // .................................................................*......... + // smlal v12.4S, v28.4H, v15.4H // ...............................................................*........... + // smlal2 v11.4S, v28.8H, v15.8H // ................................................................*.......... + // ld1 {v22.8H}, [x6], #16 // .......................................................................*... + // uzp2 v19.8H, v23.8H, v20.8H // .....................................................................*..... 
+ // uzp1 v1.8H, v12.8H, v11.8H // ....................................................................*...... + // smull v23.4S, v3.4H, v17.4H // ...................................................................*....... + // mul v14.8H, v1.8H, v2.8H // ......................................................................*.... + // zip2 v9.8H, v19.8H, v27.8H // ........................................................................*.. + // ld1 {v4.8H}, [x3], #16 // ..........................................................................* + // smlal2 v11.4S, v14.8H, v0.8H // .........................................................................*. + + sub count, count, #2 +1: + // Instructions: 48 + // Expected cycles: 58 + // Expected IPC: 0.83 + + // Cycle bound: 58.0 + // IPC bound: 0.83 + + // Wall time: 6.39s + // User time: 6.39s + + // -------------- original position --------------> + // 0 25 + // |------------------------|---------------------- + smull2 v20.4S, v3.8H, v17.8H // ..........*..................................... + ldr q18, [x4], #32 // .................e.............................. + ldr q30, [x5], #32 // .....................e.......................... + smlal2 v20.4S, v10.8H, v4.8H // ............*................................... + smlal v12.4S, v14.4H, v0.4H // .........................................*...... + smlal v23.4S, v10.4H, v4.4H // ...........*.................................... + str q9, [x0, #16] // ...............................................l + smlal2 v20.4S, v13.8H, v15.8H // ...........................*.................... + ldr q8, [x2], #32 // ....e........................................... + smlal v23.4S, v13.4H, v15.4H // ..........................*..................... + smlal2 v20.4S, v28.8H, v22.8H // .............................*.................. + zip1 v26.8H, v19.8H, v27.8H // ............................................l... 
+ ldr q9, [x2, #-16] // .....e.......................................... + smlal v23.4S, v28.4H, v22.4H // ............................*................... + uzp2 v27.8H, v12.8H, v11.8H // ...........................................*.... + uzp1 v17.8H, v8.8H, v9.8H // ......e......................................... + uzp2 v4.8H, v8.8H, v9.8H // .......e........................................ + uzp1 v5.8H, v23.8H, v20.8H // ..................................*............. + str q26, [x0], #32 // ..............................................l. + mul v31.8H, v5.8H, v2.8H // ...................................*............ + ldr q19, [x4, #-16] // ..................e............................. + ldr q29, [x1], #32 // e............................................... + ldr q12, [x1, #-16] // .e.............................................. + smlal2 v20.4S, v31.8H, v0.8H // .....................................*.......... + uzp1 v13.8H, v18.8H, v19.8H // ...................e............................ + uzp1 v3.8H, v29.8H, v12.8H // ..e............................................. + uzp2 v10.8H, v29.8H, v12.8H // ...e............................................ + smull v12.4S, v3.4H, v4.4H // .............e.................................. + smull2 v11.4S, v3.8H, v4.8H // ..............e................................. + ldr q5, [x5, #-16] // ......................e......................... + smlal v12.4S, v10.4H, v17.4H // ...............e................................ + smlal2 v11.4S, v10.8H, v17.8H // ................e............................... + uzp2 v14.8H, v30.8H, v5.8H // ........................e....................... + uzp1 v15.8H, v30.8H, v5.8H // .......................e........................ + smlal v12.4S, v13.4H, v14.4H // ..............................e................. + smlal2 v11.4S, v13.8H, v14.8H // ...............................e................ + uzp2 v28.8H, v18.8H, v19.8H // ....................e........................... 
+ smlal v23.4S, v31.4H, v0.4H // ....................................*........... + smlal v12.4S, v28.4H, v15.4H // ................................e............... + smlal2 v11.4S, v28.8H, v15.8H // .................................e.............. + ld1 {v22.8H}, [x6], #16 // .........................e...................... + uzp2 v19.8H, v23.8H, v20.8H // ......................................*......... + uzp1 v1.8H, v12.8H, v11.8H // .......................................e........ + smull v23.4S, v3.4H, v17.4H // .........e...................................... + mul v14.8H, v1.8H, v2.8H // ........................................e....... + zip2 v9.8H, v19.8H, v27.8H // .............................................*.. + ld1 {v4.8H}, [x3], #16 // ........e....................................... + smlal2 v11.4S, v14.8H, v0.8H // ..........................................e..... + + // ------------------------------------------------- new position --------------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + // ldr q12, [x1], #32 // ....................e..........................'....................~..........................'.................. + // ldr q13, [x1, #-16] // .....................e.........................'.....................~.........................'.................. + // uzp1 v3.8h, v12.8h, v13.8h // ........................e......................'........................~......................'.................. + // uzp2 v4.8h, v12.8h, v13.8h // .........................e.....................'.........................~.....................'.................. + // ldr q12, [x2], #32 // .......e.......................................'.......~.......................................'.......~.......... 
+ // ldr q13, [x2, #-16] // ...........e...................................'...........~...................................'...........~...... + // uzp1 v5.8h, v12.8h, v13.8h // ..............e................................'..............~................................'..............~... + // uzp2 v6.8h, v12.8h, v13.8h // ...............e...............................'...............~...............................'...............~.. + // ld1 {v7.8h}, [x3], #16 // .............................................e.'.............................................~.'.................. + // smull v8.4s, v3.4h, v5.4h // ..........................................e....'..........................................~....'.................. + // smull2 v10.4s, v3.8h, v5.8h // ...............................................*...............................................~.................. + // smlal v8.4s, v4.4h, v7.4h // ....~..........................................'....*..........................................'....~............. + // smlal2 v10.4s, v4.8h, v7.8h // ..~............................................'..*............................................'..~............... + // smull v9.4s, v3.4h, v6.4h // ..........................e....................'..........................~....................'.................. + // smull2 v11.4s, v3.8h, v6.8h // ...........................e...................'...........................~...................'.................. + // smlal v9.4s, v4.4h, v5.4h // .............................e.................'.............................~.................'.................. + // smlal2 v11.4s, v4.8h, v5.8h // ..............................e................'..............................~................'.................. + // ldr q12, [x4], #32 // e..............................................'~..............................................'~................. 
+ // ldr q13, [x4, #-16] // ...................e...........................'...................~...........................'.................. + // uzp1 v3.8h, v12.8h, v13.8h // .......................e.......................'.......................~.......................'.................. + // uzp2 v4.8h, v12.8h, v13.8h // ...................................e...........'...................................~...........'.................. + // ldr q12, [x5], #32 // .e.............................................'.~.............................................'.~................ + // ldr q13, [x5, #-16] // ............................e..................'............................~..................'.................. + // uzp1 v5.8h, v12.8h, v13.8h // ................................e..............'................................~..............'.................. + // uzp2 v6.8h, v12.8h, v13.8h // ...............................e...............'...............................~...............'.................. + // ld1 {v7.8h}, [x6], #16 // .......................................e.......'.......................................~.......'.................. + // smlal v8.4s, v3.4h, v5.4h // ........~......................................'........*......................................'........~......... + // smlal2 v10.4s, v3.8h, v5.8h // ......~........................................'......*........................................'......~........... + // smlal v8.4s, v4.4h, v7.4h // ............~..................................'............*..................................'............~..... + // smlal2 v10.4s, v4.8h, v7.8h // .........~.....................................'.........*.....................................'.........~........ + // smlal v9.4s, v3.4h, v6.4h // .................................e.............'.................................~.............'.................. 
+ // smlal2 v11.4s, v3.8h, v6.8h // ..................................e............'..................................~............'.................. + // smlal v9.4s, v4.4h, v5.4h // .....................................e.........'.....................................~.........'.................. + // smlal2 v11.4s, v4.8h, v5.8h // ......................................e........'......................................~........'.................. + // uzp1 v28.8h, v8.8h, v10.8h // ................~..............................'................*..............................'................~. + // mul v28.8h, v28.8h, v2.8h // ..................~............................'..................*............................'.................. + // smlal v8.4s, v28.4h, v0.4h // ....................................~..........'....................................*..........'.................. + // smlal2 v10.4s, v28.8h, v0.8h // ......................~........................'......................*........................'.................. + // uzp2 v26.8h, v8.8h, v10.8h // ........................................~......'........................................*......'.................. + // uzp1 v28.8h, v9.8h, v11.8h // .........................................e.....'.........................................~.....'.................. + // mul v28.8h, v28.8h, v2.8h // ...........................................e...'...........................................~...'.................. + // smlal v9.4s, v28.4h, v0.4h // ...~...........................................'...*...........................................'...~.............. + // smlal2 v11.4s, v28.8h, v0.8h // ..............................................e'..............................................~'.................. + // uzp2 v27.8h, v9.8h, v11.8h // .............~.................................'.............*.................................'.............~.... 
+ // zip1 v12.8h, v26.8h, v27.8h // ..........~....................................'..........~....................................'..........l....... + // zip2 v13.8h, v26.8h, v27.8h // ............................................~..'............................................*..'.................. + // str q12, [x0], #32 // .................~.............................'.................~.............................'.................l + // str q13, [x0, #-16] // .....~.........................................'.....~.........................................'.....l............ + + sub count, count, #1 + cbnz count, 1b + // Instructions: 21 + // Expected cycles: 35 + // Expected IPC: 0.60 + + // Cycle bound: 35.0 + // IPC bound: 0.60 + + // Wall time: 0.08s + // User time: 0.08s + + // ----- original position -----> + // 0 25 + // |------------------------|---- + smull2 v5.4S, v3.8H, v17.8H // *............................. + smlal v12.4S, v14.4H, v0.4H // ..*........................... + smlal v23.4S, v10.4H, v4.4H // ...*.......................... + str q9, [x0, #16] // ....*......................... + smlal2 v5.4S, v10.8H, v4.8H // .*............................ + uzp2 v11.8H, v12.8H, v11.8H // ..........*................... + zip1 v9.8H, v19.8H, v27.8H // ........*..................... + smlal v23.4S, v13.4H, v15.4H // ......*....................... + smlal2 v5.4S, v13.8H, v15.8H // .....*........................ + str q9, [x0], #32 // ............*................. + smlal v23.4S, v28.4H, v22.4H // .........*.................... + smlal2 v5.4S, v28.8H, v22.8H // .......*...................... + uzp1 v9.8H, v23.8H, v5.8H // ...........*.................. + mul v9.8H, v9.8H, v2.8H // .............*................ + smlal2 v5.4S, v9.8H, v0.8H // ..............*............... + smlal v23.4S, v9.4H, v0.4H // ...............*.............. + uzp2 v9.8H, v23.8H, v5.8H // ................*............. + zip2 v5.8H, v9.8H, v11.8H // .................*............ 
+ zip1 v9.8H, v9.8H, v11.8H // ...................*.......... + str q5, [x0, #16] // ..................*........... + str q9, [x0], #32 // ....................*......... + + // -------- new position --------> + // 0 25 + // |------------------------|----- + // smull2 v20.4S, v3.8H, v17.8H // *.............................. + // smlal2 v20.4S, v10.8H, v4.8H // ....*.......................... + // smlal v12.4S, v14.4H, v0.4H // .*............................. + // smlal v23.4S, v10.4H, v4.4H // ..*............................ + // str q9, [x0, #16] // ...*........................... + // smlal2 v20.4S, v13.8H, v15.8H // ........*...................... + // smlal v23.4S, v13.4H, v15.4H // .......*....................... + // smlal2 v20.4S, v28.8H, v22.8H // ...........*................... + // zip1 v26.8H, v19.8H, v27.8H // ......*........................ + // smlal v23.4S, v28.4H, v22.4H // ..........*.................... + // uzp2 v27.8H, v12.8H, v11.8H // .....*......................... + // uzp1 v5.8H, v23.8H, v20.8H // ............*.................. + // str q26, [x0], #32 // .........*..................... + // mul v31.8H, v5.8H, v2.8H // .............*................. + // smlal2 v20.4S, v31.8H, v0.8H // ..............*................ + // smlal v23.4S, v31.4H, v0.4H // ...............*............... + // uzp2 v19.8H, v23.8H, v20.8H // ................*.............. + // zip2 v9.8H, v19.8H, v27.8H // .................*............. + // str q9, [x0, #16] // ...................*........... + // zip1 v26.8H, v19.8H, v27.8H // ..................*............ + // str q26, [x0], #32 // ....................*.......... 
+ + + pop_stack + ret +#endif /* MLKEM_K == 2 */ + +#if MLKEM_K == 3 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + + mov count, #(MLKEM_N / 16) + // Instructions: 75 + // Expected cycles: 103 + // Expected IPC: 0.73 + + // Cycle bound: 103.0 + // IPC bound: 0.73 + + // Wall time: 0.94s + // User time: 0.94s + + // --------------------------- original position ----------------------------> + // 0 25 50 + // |------------------------|------------------------| + ldr q7, [x2, #16] // *.......................................................................... + ldr q20, [x2], #32 // ..*........................................................................ + ldr q15, [x1, #16] // .*......................................................................... + uzp1 v8.8H, v20.8H, v7.8H // ...............*........................................................... + uzp2 v7.8H, v20.8H, v7.8H // ................*.......................................................... + ld1 {v20.8H}, [x3], #16 // ...*....................................................................... + ldr q30, [x1], #32 // ..............*............................................................ + ldr q11, [x4], #32 // ....*...................................................................... + uzp1 v16.8H, v30.8H, v15.8H // .................*......................................................... + uzp2 v15.8H, v30.8H, v15.8H // ..................*........................................................ 
+ smull v30.4S, v16.4H, v7.4H // ...................*....................................................... + smull2 v7.4S, v16.8H, v7.8H // ....................*...................................................... + smull v9.4S, v16.4H, v8.4H // .....................*..................................................... + smull2 v16.4S, v16.8H, v8.8H // ......................*.................................................... + smlal v30.4S, v15.4H, v8.4H // .......................*................................................... + smlal2 v7.4S, v15.8H, v8.8H // ........................*.................................................. + smlal v9.4S, v15.4H, v20.4H // .........................*................................................. + smlal2 v16.4S, v15.8H, v20.8H // ..........................*................................................ + ldr q20, [x4, #-16] // .....*..................................................................... + ldr q15, [x5], #32 // ......*.................................................................... + uzp1 v8.8H, v11.8H, v20.8H // ...........................*............................................... + uzp2 v20.8H, v11.8H, v20.8H // ............................*.............................................. + ldr q11, [x5, #-16] // .......*................................................................... + ld1 {v27.8H}, [x6], #16 // ........*.................................................................. + uzp1 v10.8H, v15.8H, v11.8H // .............................*............................................. + uzp2 v15.8H, v15.8H, v11.8H // ..............................*............................................ + smlal v9.4S, v8.4H, v10.4H // ...............................*........................................... + smlal2 v16.4S, v8.8H, v10.8H // ................................*.......................................... 
+ smlal v30.4S, v8.4H, v15.4H // .................................*......................................... + smlal2 v7.4S, v8.8H, v15.8H // ..................................*........................................ + smlal v9.4S, v20.4H, v27.4H // ...................................*....................................... + smlal2 v16.4S, v20.8H, v27.8H // ....................................*...................................... + smlal v30.4S, v20.4H, v10.4H // .....................................*..................................... + smlal2 v7.4S, v20.8H, v10.8H // ......................................*.................................... + ldr q20, [x7], #32 // .........*................................................................. + ldr q15, [x7, #-16] // ..........*................................................................ + ldr q8, [x8], #32 // ...........*............................................................... + uzp1 v11.8H, v20.8H, v15.8H // .......................................*................................... + uzp2 v20.8H, v20.8H, v15.8H // ........................................*.................................. + ldr q15, [x8, #-16] // ............*.............................................................. + ld1 {v27.8H}, [x9], #16 // .............*............................................................. + uzp1 v10.8H, v8.8H, v15.8H // .........................................*................................. + uzp2 v15.8H, v8.8H, v15.8H // ..........................................*................................ + smlal v9.4S, v11.4H, v10.4H // ...........................................*............................... + smlal2 v16.4S, v11.8H, v10.8H // ............................................*.............................. + smlal v30.4S, v11.4H, v15.4H // .............................................*............................. 
+ smlal2 v7.4S, v11.8H, v15.8H // ..............................................*............................ + smlal v9.4S, v20.4H, v27.4H // ...............................................*........................... + smlal2 v16.4S, v20.8H, v27.8H // ................................................*.......................... + smlal v30.4S, v20.4H, v10.4H // .................................................*......................... + smlal2 v7.4S, v20.8H, v10.8H // ..................................................*........................ + ldr q15, [x2], #32 // ...............................................................*........... + uzp1 v20.8H, v9.8H, v16.8H // ....................................................*...................... + uzp1 v8.8H, v30.8H, v7.8H // .....................................................*..................... + mul v20.8H, v20.8H, v2.8H // ......................................................*.................... + mul v8.8H, v8.8H, v2.8H // .......................................................*................... + ldr q21, [x4], #32 // .................................................................*......... + smlal v9.4S, v20.4H, v0.4H // ........................................................*.................. + smlal2 v16.4S, v20.8H, v0.8H // .........................................................*................. + smlal v30.4S, v8.4H, v0.4H // ..........................................................*................ + smlal2 v7.4S, v8.8H, v0.8H // ...........................................................*............... + ldr q6, [x4, #-16] // ..................................................................*........ + uzp2 v27.8H, v9.8H, v16.8H // ............................................................*.............. + uzp2 v10.8H, v30.8H, v7.8H // .............................................................*............. 
+ ldr q16, [x2, #-16] // ...................................................*....................... + ldr q30, [x1, #16] // ..............................................................*............ + ld1 {v9.8H}, [x3], #16 // ................................................................*.......... + ldr q1, [x5], #32 // ...................................................................*....... + ldr q12, [x5, #-16] // ....................................................................*...... + ld1 {v24.8H}, [x6], #16 // .....................................................................*..... + ldr q19, [x7], #32 // ......................................................................*.... + ldr q31, [x7, #-16] // .......................................................................*... + ldr q17, [x8], #32 // ........................................................................*.. + ldr q18, [x8, #-16] // .........................................................................*. + ld1 {v25.8H}, [x9], #16 // ..........................................................................* + + // ------------------------------ new position ------------------------------> + // 0 25 50 + // |------------------------|------------------------|------------------------ + // ldr q16, [x2, #16] // *.......................................................................... + // ldr q30, [x1, #16] // ..*........................................................................ + // ldr q15, [x2], #32 // .*......................................................................... + // ld1 {v9.8H}, [x3], #16 // .....*..................................................................... + // ldr q21, [x4], #32 // .......*................................................................... + // ldr q6, [x4, #-16] // ..................*........................................................ 
+ // ldr q1, [x5], #32 // ...................*....................................................... + // ldr q12, [x5, #-16] // ......................*.................................................... + // ld1 {v24.8H}, [x6], #16 // .......................*................................................... + // ldr q19, [x7], #32 // ..................................*........................................ + // ldr q31, [x7, #-16] // ...................................*....................................... + // ldr q17, [x8], #32 // ....................................*...................................... + // ldr q18, [x8, #-16] // .......................................*................................... + // ld1 {v25.8H}, [x9], #16 // ........................................*.................................. + // ldr q20, [x1], #32 // ......*.................................................................... + // uzp1 v7.8H, v15.8H, v16.8H // ...*....................................................................... + // uzp2 v15.8H, v15.8H, v16.8H // ....*...................................................................... + // uzp1 v8.8H, v20.8H, v30.8H // ........*.................................................................. + // uzp2 v20.8H, v20.8H, v30.8H // .........*................................................................. + // smull v30.4S, v8.4H, v15.4H // ..........*................................................................ + // smull2 v15.4S, v8.8H, v15.8H // ...........*............................................................... + // smull v11.4S, v8.4H, v7.4H // ............*.............................................................. + // smull2 v8.4S, v8.8H, v7.8H // .............*............................................................. + // smlal v30.4S, v20.4H, v7.4H // ..............*............................................................ 
+ // smlal2 v15.4S, v20.8H, v7.8H // ...............*........................................................... + // smlal v11.4S, v20.4H, v9.4H // ................*.......................................................... + // smlal2 v8.4S, v20.8H, v9.8H // .................*......................................................... + // uzp1 v7.8H, v21.8H, v6.8H // ....................*...................................................... + // uzp2 v20.8H, v21.8H, v6.8H // .....................*..................................................... + // uzp1 v16.8H, v1.8H, v12.8H // ........................*.................................................. + // uzp2 v9.8H, v1.8H, v12.8H // .........................*................................................. + // smlal v11.4S, v7.4H, v16.4H // ..........................*................................................ + // smlal2 v8.4S, v7.8H, v16.8H // ...........................*............................................... + // smlal v30.4S, v7.4H, v9.4H // ............................*.............................................. + // smlal2 v15.4S, v7.8H, v9.8H // .............................*............................................. + // smlal v11.4S, v20.4H, v24.4H // ..............................*............................................ + // smlal2 v8.4S, v20.8H, v24.8H // ...............................*........................................... + // smlal v30.4S, v20.4H, v16.4H // ................................*.......................................... + // smlal2 v15.4S, v20.8H, v16.8H // .................................*......................................... + // uzp1 v7.8H, v19.8H, v31.8H // .....................................*..................................... + // uzp2 v20.8H, v19.8H, v31.8H // ......................................*.................................... 
+ // uzp1 v16.8H, v17.8H, v18.8H // .........................................*................................. + // uzp2 v9.8H, v17.8H, v18.8H // ..........................................*................................ + // smlal v11.4S, v7.4H, v16.4H // ...........................................*............................... + // smlal2 v8.4S, v7.8H, v16.8H // ............................................*.............................. + // smlal v30.4S, v7.4H, v9.4H // .............................................*............................. + // smlal2 v15.4S, v7.8H, v9.8H // ..............................................*............................ + // smlal v11.4S, v20.4H, v25.4H // ...............................................*........................... + // smlal2 v8.4S, v20.8H, v25.8H // ................................................*.......................... + // smlal v30.4S, v20.4H, v16.4H // .................................................*......................... + // smlal2 v15.4S, v20.8H, v16.8H // ..................................................*........................ + // ldr q16, [x2, #16] // ................................................................*.......... + // uzp1 v7.8H, v11.8H, v8.8H // ....................................................*...................... + // uzp1 v20.8H, v30.8H, v15.8H // .....................................................*..................... + // mul v7.8H, v7.8H, v2.8H // ......................................................*.................... + // mul v20.8H, v20.8H, v2.8H // .......................................................*................... + // smlal v11.4S, v7.4H, v0.4H // .........................................................*................. + // smlal2 v8.4S, v7.8H, v0.8H // ..........................................................*................ + // smlal v30.4S, v20.4H, v0.4H // ...........................................................*............... 
+ // smlal2 v15.4S, v20.8H, v0.8H // ............................................................*.............. + // uzp2 v27.8H, v11.8H, v8.8H // ..............................................................*............ + // uzp2 v10.8H, v30.8H, v15.8H // ...............................................................*........... + // ldr q30, [x1, #16] // .................................................................*......... + // ldr q15, [x2], #32 // ...................................................*....................... + // ld1 {v9.8H}, [x3], #16 // ..................................................................*........ + // ldr q21, [x4], #32 // ........................................................*.................. + // ldr q6, [x4, #-16] // .............................................................*............. + // ldr q1, [x5], #32 // ...................................................................*....... + // ldr q12, [x5, #-16] // ....................................................................*...... + // ld1 {v24.8H}, [x6], #16 // .....................................................................*..... + // ldr q19, [x7], #32 // ......................................................................*.... + // ldr q31, [x7, #-16] // .......................................................................*... + // ldr q17, [x8], #32 // ........................................................................*.. + // ldr q18, [x8, #-16] // .........................................................................*. 
+ // ld1 {v25.8H}, [x9], #16 // ..........................................................................* + + sub count, count, #2 +1: + // Instructions: 65 + // Expected cycles: 80 + // Expected IPC: 0.81 + + // Cycle bound: 80.0 + // IPC bound: 0.81 + + // Wall time: 11.64s + // User time: 11.64s + + // ---------------------- original position -----------------------> + // 0 25 50 + // |------------------------|------------------------|-------------- + ldr q20, [x1], #32 // *................................................................ + uzp1 v7.8H, v15.8H, v16.8H // ......*.......................................................... + uzp2 v15.8H, v15.8H, v16.8H // .......*......................................................... + uzp1 v8.8H, v20.8H, v30.8H // ..*.............................................................. + uzp2 v20.8H, v20.8H, v30.8H // ...*............................................................. + smull v30.4S, v8.4H, v15.4H // .............*................................................... + smull2 v15.4S, v8.8H, v15.8H // ..............*.................................................. + smull v11.4S, v8.4H, v7.4H // .........*....................................................... + smull2 v8.4S, v8.8H, v7.8H // ..........*...................................................... + smlal v30.4S, v20.4H, v7.4H // ...............*................................................. + smlal2 v15.4S, v20.8H, v7.8H // ................*................................................ + smlal v11.4S, v20.4H, v9.4H // ...........*..................................................... + smlal2 v8.4S, v20.8H, v9.8H // ............*.................................................... + uzp1 v7.8H, v21.8H, v6.8H // ...................*............................................. + uzp2 v20.8H, v21.8H, v6.8H // ....................*............................................ 
+ uzp1 v16.8H, v1.8H, v12.8H // .......................*......................................... + uzp2 v9.8H, v1.8H, v12.8H // ........................*........................................ + smlal v11.4S, v7.4H, v16.4H // ..........................*...................................... + smlal2 v8.4S, v7.8H, v16.8H // ...........................*..................................... + smlal v30.4S, v7.4H, v9.4H // ..............................*.................................. + smlal2 v15.4S, v7.8H, v9.8H // ...............................*................................. + smlal v11.4S, v20.4H, v24.4H // ............................*.................................... + smlal2 v8.4S, v20.8H, v24.8H // .............................*................................... + smlal v30.4S, v20.4H, v16.4H // ................................*................................ + smlal2 v15.4S, v20.8H, v16.8H // .................................*............................... + uzp1 v7.8H, v19.8H, v31.8H // ....................................*............................ + uzp2 v20.8H, v19.8H, v31.8H // .....................................*........................... + uzp1 v16.8H, v17.8H, v18.8H // ........................................*........................ + uzp2 v9.8H, v17.8H, v18.8H // .........................................*....................... + smlal v11.4S, v7.4H, v16.4H // ...........................................*..................... + smlal2 v8.4S, v7.8H, v16.8H // ............................................*.................... + smlal v30.4S, v7.4H, v9.4H // ...............................................*................. + smlal2 v15.4S, v7.8H, v9.8H // ................................................*................ + smlal v11.4S, v20.4H, v25.4H // .............................................*................... + smlal2 v8.4S, v20.8H, v25.8H // ..............................................*.................. 
+ smlal v30.4S, v20.4H, v16.4H // .................................................*............... + smlal2 v15.4S, v20.8H, v16.8H // ..................................................*.............. + ldr q16, [x2, #16] // .....e........................................................... + uzp1 v7.8H, v11.8H, v8.8H // ...................................................*............. + uzp1 v20.8H, v30.8H, v15.8H // ........................................................*........ + mul v7.8H, v7.8H, v2.8H // ....................................................*............ + mul v20.8H, v20.8H, v2.8H // .........................................................*....... + zip2 v9.8H, v27.8H, v10.8H // ..............................................................l.. + zip1 v27.8H, v27.8H, v10.8H // .............................................................l... + smlal v11.4S, v7.4H, v0.4H // .....................................................*........... + smlal2 v8.4S, v7.8H, v0.8H // ......................................................*.......... + smlal v30.4S, v20.4H, v0.4H // ..........................................................*...... + smlal2 v15.4S, v20.8H, v0.8H // ...........................................................*..... + str q27, [x0], #32 // ...............................................................l. + uzp2 v27.8H, v11.8H, v8.8H // .......................................................*......... + str q9, [x0, #-16] // ................................................................l + uzp2 v10.8H, v30.8H, v15.8H // ............................................................*.... + ldr q30, [x1, #16] // .e............................................................... + ldr q15, [x2], #32 // ....e............................................................ + ld1 {v9.8H}, [x3], #16 // ........e........................................................ 
+ ldr q21, [x4], #32 // .................e............................................... + ldr q6, [x4, #-16] // ..................e.............................................. + ldr q1, [x5], #32 // .....................e........................................... + ldr q12, [x5, #-16] // ......................e.......................................... + ld1 {v24.8H}, [x6], #16 // .........................e....................................... + ldr q19, [x7], #32 // ..................................e.............................. + ldr q31, [x7, #-16] // ...................................e............................. + ldr q17, [x8], #32 // ......................................e.......................... + ldr q18, [x8, #-16] // .......................................e......................... + ld1 {v25.8H}, [x9], #16 // ..........................................e...................... + + // ---------------------------------------------------------------- new position -----------------------------------------------------------------> + // 0 25 50 75 100 125 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------ + // ldr q12, [x1], #32 // ............................*................................................................~.................................................. + // ldr q13, [x1, #-16] // ...............e............'...................................................~............'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'..*.............................................................'..~............................................... + // uzp2 v4.8h, v12.8h, v13.8h // ............................'...*............................................................'...~.............................................. 
+ // ldr q12, [x2], #32 // ................e...........'....................................................~...........'.................................................. + // ldr q13, [x2, #-16] // e...........................'....................................~...........................'....................................~............. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'*...............................................................'~................................................. + // uzp2 v6.8h, v12.8h, v13.8h // ............................'.*..............................................................'.~................................................ + // ld1 {v7.8h}, [x3], #16 // .................e..........'.....................................................~..........'.................................................. + // smull v8.4s, v3.4h, v5.4h // ............................'......*.........................................................'......~........................................... + // smull2 v10.4s, v3.8h, v5.8h // ............................'.......*........................................................'.......~.......................................... + // smlal v8.4s, v4.4h, v7.4h // ............................'..........*.....................................................'..........~....................................... + // smlal2 v10.4s, v4.8h, v7.8h // ............................'...........*....................................................'...........~...................................... + // smull v9.4s, v3.4h, v6.4h // ............................'....*...........................................................'....~............................................. + // smull2 v11.4s, v3.8h, v6.8h // ............................'.....*..........................................................'.....~............................................ 
+ // smlal v9.4s, v4.4h, v5.4h // ............................'........*.......................................................'........~......................................... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'.........*......................................................'.........~........................................ + // ldr q12, [x4], #32 // ..................e.........'......................................................~.........'.................................................. + // ldr q13, [x4, #-16] // ...................e........'.......................................................~........'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'............*...................................................'............~..................................... + // uzp2 v4.8h, v12.8h, v13.8h // ............................'.............*..................................................'.............~.................................... + // ldr q12, [x5], #32 // ....................e.......'........................................................~.......'.................................................. + // ldr q13, [x5, #-16] // .....................e......'.........................................................~......'.................................................. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'..............*.................................................'..............~................................... + // uzp2 v6.8h, v12.8h, v13.8h // ............................'...............*................................................'...............~.................................. + // ld1 {v7.8h}, [x6], #16 // ......................e.....'..........................................................~.....'.................................................. 
+ // smlal v8.4s, v3.4h, v5.4h // ............................'................*...............................................'................~................................. + // smlal2 v10.4s, v3.8h, v5.8h // ............................'.................*..............................................'.................~................................ + // smlal v8.4s, v4.4h, v7.4h // ............................'....................*...........................................'....................~............................. + // smlal2 v10.4s, v4.8h, v7.8h // ............................'.....................*..........................................'.....................~............................ + // smlal v9.4s, v3.4h, v6.4h // ............................'..................*.............................................'..................~............................... + // smlal2 v11.4s, v3.8h, v6.8h // ............................'...................*............................................'...................~.............................. + // smlal v9.4s, v4.4h, v5.4h // ............................'......................*.........................................'......................~........................... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'.......................*........................................'.......................~.......................... + // ldr q12, [x7], #32 // .......................e....'...........................................................~....'.................................................. + // ldr q13, [x7, #-16] // ........................e...'............................................................~...'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'........................*.......................................'........................~......................... 
+ // uzp2 v4.8h, v12.8h, v13.8h // ............................'.........................*......................................'.........................~........................ + // ldr q12, [x8], #32 // .........................e..'.............................................................~..'.................................................. + // ldr q13, [x8, #-16] // ..........................e.'..............................................................~.'.................................................. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'..........................*.....................................'..........................~....................... + // uzp2 v6.8h, v12.8h, v13.8h // ............................'...........................*....................................'...........................~...................... + // ld1 {v7.8h}, [x9], #16 // ...........................e'...............................................................~'.................................................. + // smlal v8.4s, v3.4h, v5.4h // ............................'............................*...................................'............................~..................... + // smlal2 v10.4s, v3.8h, v5.8h // ............................'.............................*..................................'.............................~.................... + // smlal v8.4s, v4.4h, v7.4h // ............................'................................*...............................'................................~................. + // smlal2 v10.4s, v4.8h, v7.8h // ............................'.................................*..............................'.................................~................ + // smlal v9.4s, v3.4h, v6.4h // ............................'..............................*.................................'..............................~................... 
+ // smlal2 v11.4s, v3.8h, v6.8h // ............................'...............................*................................'...............................~.................. + // smlal v9.4s, v4.4h, v5.4h // ............................'..................................*.............................'..................................~............... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'...................................*............................'...................................~.............. + // uzp1 v28.8h, v8.8h, v10.8h // .~..........................'.....................................*..........................'.....................................~............ + // mul v28.8h, v28.8h, v2.8h // ...~........................'.......................................*........................'.......................................~.......... + // smlal v8.4s, v28.4h, v0.4h // .......~....................'...........................................*....................'...........................................~...... + // smlal2 v10.4s, v28.8h, v0.8h // ........~...................'............................................*...................'............................................~..... + // uzp2 v26.8h, v8.8h, v10.8h // ............~...............'................................................*...............'................................................~. + // uzp1 v28.8h, v9.8h, v11.8h // ..~.........................'......................................*.........................'......................................~........... + // mul v28.8h, v28.8h, v2.8h // ....~.......................'........................................*.......................'........................................~......... + // smlal v9.4s, v28.4h, v0.4h // .........~..................'.............................................*..................'.............................................~.... 
+ // smlal2 v11.4s, v28.8h, v0.8h // ..........~.................'..............................................*.................'..............................................~... + // uzp2 v27.8h, v9.8h, v11.8h // ..............~.............'..................................................*.............'.................................................. + // zip1 v12.8h, v26.8h, v27.8h // ......~.....................'..........................................~.....................'..........................................l....... + // zip2 v13.8h, v26.8h, v27.8h // .....~......................'.........................................~......................'.........................................l........ + // str q12, [x0], #32 // ...........~................'...............................................~................'...............................................l.. + // str q13, [x0, #-16] // .............~..............'.................................................~..............'.................................................l + + sub count, count, #1 + cbnz count, 1b + // Instructions: 55 + // Expected cycles: 61 + // Expected IPC: 0.90 + + // Cycle bound: 61.0 + // IPC bound: 0.90 + + // Wall time: 8.41s + // User time: 8.41s + + // ----------------- original position ------------------> + // 0 25 50 + // |------------------------|------------------------|---- + ldr q7, [x1], #32 // *...................................................... + uzp1 v20.8H, v15.8H, v16.8H // .*..................................................... + uzp2 v15.8H, v15.8H, v16.8H // ..*.................................................... + uzp1 v23.8H, v7.8H, v30.8H // ...*................................................... + uzp2 v11.8H, v7.8H, v30.8H // ....*.................................................. + smull2 v8.4S, v23.8H, v20.8H // ........*.............................................. 
+ smull v5.4S, v23.4H, v20.4H // .......*............................................... + smull2 v30.4S, v23.8H, v15.8H // ......*................................................ + uzp1 v28.8H, v1.8H, v12.8H // ...............*....................................... + smlal2 v8.4S, v11.8H, v9.8H // ............*.......................................... + smlal v5.4S, v11.4H, v9.4H // ...........*........................................... + uzp1 v3.8H, v21.8H, v6.8H // .............*......................................... + smull v16.4S, v23.4H, v15.4H // .....*................................................. + smlal2 v8.4S, v3.8H, v28.8H // ..................*.................................... + smlal v5.4S, v3.4H, v28.4H // .................*..................................... + uzp2 v29.8H, v21.8H, v6.8H // ..............*........................................ + uzp1 v7.8H, v17.8H, v18.8H // ...........................*........................... + smlal2 v8.4S, v29.8H, v24.8H // ......................*................................ + uzp1 v14.8H, v19.8H, v31.8H // .........................*............................. + smlal v16.4S, v11.4H, v20.4H // .........*............................................. + smlal2 v30.4S, v11.8H, v20.8H // ..........*............................................ + smlal2 v8.4S, v14.8H, v7.8H // ..............................*........................ + uzp2 v20.8H, v1.8H, v12.8H // ................*...................................... + uzp2 v21.8H, v19.8H, v31.8H // ..........................*............................ + smlal2 v30.4S, v3.8H, v20.8H // ....................*.................................. + smlal v16.4S, v3.4H, v20.4H // ...................*................................... + smlal v5.4S, v29.4H, v24.4H // .....................*................................. + uzp2 v9.8H, v17.8H, v18.8H // ............................*.......................... 
+ smlal2 v30.4S, v29.8H, v28.8H // ........................*.............................. + smlal v16.4S, v29.4H, v28.4H // .......................*............................... + smlal v5.4S, v14.4H, v7.4H // .............................*......................... + smlal2 v8.4S, v21.8H, v25.8H // ..................................*.................... + smlal2 v30.4S, v14.8H, v9.8H // ................................*...................... + smlal v16.4S, v14.4H, v9.4H // ...............................*....................... + smlal v5.4S, v21.4H, v25.4H // .................................*..................... + zip1 v20.8H, v27.8H, v10.8H // ..........................................*............ + smlal2 v30.4S, v21.8H, v7.8H // ....................................*.................. + smlal v16.4S, v21.4H, v7.4H // ...................................*................... + uzp1 v7.8H, v5.8H, v8.8H // .....................................*................. + str q20, [x0], #32 // ...............................................*....... + mul v15.8H, v7.8H, v2.8H // .......................................*............... + uzp1 v7.8H, v16.8H, v30.8H // ......................................*................ + zip2 v31.8H, v27.8H, v10.8H // .........................................*............. + mul v20.8H, v7.8H, v2.8H // ........................................*.............. + smlal v5.4S, v15.4H, v0.4H // ...........................................*........... + smlal2 v8.4S, v15.8H, v0.8H // ............................................*.......... + str q31, [x0, #-16] // .................................................*..... + smlal2 v30.4S, v20.8H, v0.8H // ..............................................*........ + smlal v16.4S, v20.4H, v0.4H // .............................................*......... + uzp2 v15.8H, v5.8H, v8.8H // ................................................*...... 
+ uzp2 v20.8H, v16.8H, v30.8H // ..................................................*.... + zip1 v7.8H, v15.8H, v20.8H // ....................................................*.. + zip2 v20.8H, v15.8H, v20.8H // ...................................................*... + str q7, [x0], #32 // .....................................................*. + str q20, [x0, #-16] // ......................................................* + + // -------------------- new position --------------------> + // 0 25 50 + // |------------------------|------------------------|---- + // ldr q20, [x1], #32 // *...................................................... + // uzp1 v7.8H, v15.8H, v16.8H // .*..................................................... + // uzp2 v15.8H, v15.8H, v16.8H // ..*.................................................... + // uzp1 v8.8H, v20.8H, v30.8H // ...*................................................... + // uzp2 v20.8H, v20.8H, v30.8H // ....*.................................................. + // smull v30.4S, v8.4H, v15.4H // ............*.......................................... + // smull2 v15.4S, v8.8H, v15.8H // .......*............................................... + // smull v11.4S, v8.4H, v7.4H // ......*................................................ + // smull2 v8.4S, v8.8H, v7.8H // .....*................................................. + // smlal v30.4S, v20.4H, v7.4H // ...................*................................... + // smlal2 v15.4S, v20.8H, v7.8H // ....................*.................................. + // smlal v11.4S, v20.4H, v9.4H // ..........*............................................ + // smlal2 v8.4S, v20.8H, v9.8H // .........*............................................. + // uzp1 v7.8H, v21.8H, v6.8H // ...........*........................................... + // uzp2 v20.8H, v21.8H, v6.8H // ...............*....................................... 
+ // uzp1 v16.8H, v1.8H, v12.8H // ........*.............................................. + // uzp2 v9.8H, v1.8H, v12.8H // ......................*................................ + // smlal v11.4S, v7.4H, v16.4H // ..............*........................................ + // smlal2 v8.4S, v7.8H, v16.8H // .............*......................................... + // smlal v30.4S, v7.4H, v9.4H // .........................*............................. + // smlal2 v15.4S, v7.8H, v9.8H // ........................*.............................. + // smlal v11.4S, v20.4H, v24.4H // ..........................*............................ + // smlal2 v8.4S, v20.8H, v24.8H // .................*..................................... + // smlal v30.4S, v20.4H, v16.4H // .............................*......................... + // smlal2 v15.4S, v20.8H, v16.8H // ............................*.......................... + // uzp1 v7.8H, v19.8H, v31.8H // ..................*.................................... + // uzp2 v20.8H, v19.8H, v31.8H // .......................*............................... + // uzp1 v16.8H, v17.8H, v18.8H // ................*...................................... + // uzp2 v9.8H, v17.8H, v18.8H // ...........................*........................... + // smlal v11.4S, v7.4H, v16.4H // ..............................*........................ + // smlal2 v8.4S, v7.8H, v16.8H // .....................*................................. + // smlal v30.4S, v7.4H, v9.4H // .................................*..................... + // smlal2 v15.4S, v7.8H, v9.8H // ................................*...................... + // smlal v11.4S, v20.4H, v25.4H // ..................................*.................... + // smlal2 v8.4S, v20.8H, v25.8H // ...............................*....................... + // smlal v30.4S, v20.4H, v16.4H // .....................................*................. 
+ // smlal2 v15.4S, v20.8H, v16.8H // ....................................*.................. + // uzp1 v7.8H, v11.8H, v8.8H // ......................................*................ + // uzp1 v20.8H, v30.8H, v15.8H // .........................................*............. + // mul v7.8H, v7.8H, v2.8H // ........................................*.............. + // mul v20.8H, v20.8H, v2.8H // ...........................................*........... + // zip2 v9.8H, v27.8H, v10.8H // ..........................................*............ + // zip1 v27.8H, v27.8H, v10.8H // ...................................*................... + // smlal v11.4S, v7.4H, v0.4H // ............................................*.......... + // smlal2 v8.4S, v7.8H, v0.8H // .............................................*......... + // smlal v30.4S, v20.4H, v0.4H // ................................................*...... + // smlal2 v15.4S, v20.8H, v0.8H // ...............................................*....... + // str q27, [x0], #32 // .......................................*............... + // uzp2 v27.8H, v11.8H, v8.8H // .................................................*..... + // str q9, [x0, #-16] // ..............................................*........ + // uzp2 v10.8H, v30.8H, v15.8H // ..................................................*.... + // zip2 v9.8H, v27.8H, v10.8H // ....................................................*.. + // zip1 v27.8H, v27.8H, v10.8H // ...................................................*... + // str q27, [x0], #32 // .....................................................*. 
+ // str q9, [x0, #-16] // ......................................................* + + + pop_stack + ret +#endif /* MLKEM_K == 3 */ + +#if MLKEM_K == 4 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + add a3_ptr, a0_ptr, #(3 * 512) + add b3_ptr, b0_ptr, #(3 * 512) + add b3_cache_ptr, b0_cache_ptr, #(3 * 512/2) + + // Bounds: + + // Each pmull is bound by 2*4096*2^15=2^28, so the final value + // before Montgomery reduction is bound by 2^30. + + mov count, #(MLKEM_N / 16) + // Instructions: 114 + // Expected cycles: 153 + // Expected IPC: 0.75 + // + // Cycle bound: 153.0 + // IPC bound: 0.75 + // + // Wall time: 0.69s + // User time: 0.69s + // + // ----------------------------------------------- original position -----------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + ldr q23, [x2, #16] // .*................................................................................................................ + ldr q19, [x2], #32 // *................................................................................................................. + ldr q17, [x5], #32 // ..*............................................................................................................... + uzp2 v13.8H, v19.8H, v23.8H // ..........*....................................................................................................... 
+ uzp1 v19.8H, v19.8H, v23.8H // ...........*...................................................................................................... + ldr q23, [x5, #-16] // ...*.............................................................................................................. + ldr q30, [x1, #16] // .....*............................................................................................................ + uzp2 v9.8H, v17.8H, v23.8H // ....*............................................................................................................. + uzp1 v23.8H, v17.8H, v23.8H // .......*.......................................................................................................... + ldr q17, [x1], #32 // ......*........................................................................................................... + ldr q10, [x7, #16] // .............*.................................................................................................... + uzp1 v12.8H, v17.8H, v30.8H // ........*......................................................................................................... + uzp2 v17.8H, v17.8H, v30.8H // .........*........................................................................................................ + smull2 v30.4S, v12.8H, v13.8H // ............*..................................................................................................... + smull v13.4S, v12.4H, v13.4H // ............................................*..................................................................... + smull2 v22.4S, v12.8H, v19.8H // .....................................*............................................................................ + smull v12.4S, v12.4H, v19.4H // ..........................................*....................................................................... 
+ smlal2 v30.4S, v17.8H, v19.8H // ...............................*.................................................................................. + smlal v13.4S, v17.4H, v19.4H // ...............................................*.................................................................. + ldr q19, [x4], #32 // ....................*............................................................................................. + ldr q16, [x4, #-16] // .....................*............................................................................................ + ld1 {v8.8H}, [x3], #16 // ................................*................................................................................. + uzp1 v26.8H, v19.8H, v16.8H // .......................*.......................................................................................... + uzp2 v19.8H, v19.8H, v16.8H // ........................*......................................................................................... + smlal2 v30.4S, v26.8H, v9.8H // .................................*................................................................................ + smlal v13.4S, v26.4H, v9.4H // ..................................................*............................................................... + smlal2 v22.4S, v17.8H, v8.8H // ........................................*......................................................................... + smlal v12.4S, v17.4H, v8.4H // .................................................*................................................................ + smlal2 v30.4S, v19.8H, v23.8H // ...................................*.............................................................................. + smlal v13.4S, v19.4H, v23.4H // .......................................................*.......................................................... 
+ smlal2 v22.4S, v26.8H, v23.8H // ...........................................*...................................................................... + smlal v12.4S, v26.4H, v23.4H // .....................................................*............................................................ + ldr q23, [x7], #32 // ......................*........................................................................................... + ldr q17, [x8, #16] // ..............*................................................................................................... + uzp1 v9.8H, v23.8H, v10.8H // ..........................*....................................................................................... + uzp2 v23.8H, v23.8H, v10.8H // ....................................*............................................................................. + ldr q10, [x10], #32 // ...............*.................................................................................................. + ldr q16, [x10, #-16] // ................*................................................................................................. + ld1 {v8.8H}, [x12], #16 // .................*................................................................................................ + uzp1 v26.8H, v10.8H, v16.8H // ..................*............................................................................................... + uzp2 v10.8H, v10.8H, v16.8H // ...................*.............................................................................................. + ld1 {v16.8H}, [x6], #16 // .........................*........................................................................................ + ldr q3, [x11, #16] // ...........................*...................................................................................... 
+ smlal2 v22.4S, v19.8H, v16.8H // ..............................................*................................................................... + smlal v12.4S, v19.4H, v16.4H // ........................................................*......................................................... + ldr q19, [x11], #32 // ............................*..................................................................................... + ld1 {v16.8H}, [x9], #16 // .............................*.................................................................................... + uzp1 v4.8H, v19.8H, v3.8H // ..................................*............................................................................... + uzp2 v19.8H, v19.8H, v3.8H // .......................................*.......................................................................... + ldr q3, [x8], #32 // ..............................*................................................................................... + ldr q31, [x2], #32 // ......................................*........................................................................... + uzp1 v6.8H, v3.8H, v17.8H // ...................................................*.............................................................. + uzp2 v17.8H, v3.8H, v17.8H // .........................................................*........................................................ + smlal2 v22.4S, v9.8H, v6.8H // ..........................................................*....................................................... + smlal2 v30.4S, v9.8H, v17.8H // ...........................................................*...................................................... + smlal v13.4S, v9.4H, v17.4H // ............................................................*..................................................... 
+ smlal v12.4S, v9.4H, v6.4H // .............................................................*.................................................... + smlal2 v22.4S, v23.8H, v16.8H // ..............................................................*................................................... + smlal2 v30.4S, v23.8H, v6.8H // ...............................................................*.................................................. + smlal v13.4S, v23.4H, v6.4H // ................................................................*................................................. + smlal v12.4S, v23.4H, v16.4H // .................................................................*................................................ + smlal2 v22.4S, v26.8H, v4.8H // ..................................................................*............................................... + smlal2 v30.4S, v26.8H, v19.8H // ...................................................................*.............................................. + smlal v13.4S, v26.4H, v19.4H // ....................................................................*............................................. + smlal v12.4S, v26.4H, v4.4H // .....................................................................*............................................ + smlal2 v22.4S, v10.8H, v8.8H // ......................................................................*........................................... + smlal2 v30.4S, v10.8H, v4.8H // .......................................................................*.......................................... + smlal v13.4S, v10.4H, v4.4H // ........................................................................*......................................... + smlal v12.4S, v10.4H, v8.4H // .........................................................................*........................................ 
+ ldr q19, [x2, #-16] // .........................................*........................................................................ + uzp1 v23.8H, v13.8H, v30.8H // ...........................................................................*...................................... + uzp1 v17.8H, v12.8H, v22.8H // ....................................................................................*............................. + mul v23.8H, v23.8H, v2.8H // .............................................................................*.................................... + uzp2 v21.8H, v31.8H, v19.8H // ................................................................................*................................. + uzp1 v19.8H, v31.8H, v19.8H // ...................................................................................*.............................. + mul v17.8H, v17.8H, v2.8H // .....................................................................................*............................ + smlal v13.4S, v23.4H, v0.4H // .................................................................................*................................ + smlal2 v30.4S, v23.8H, v0.8H // ..................................................................................*............................... + ldr q23, [x5], #32 // .............................................*.................................................................... + smlal2 v22.4S, v17.8H, v0.8H // ...........................................................................................................*...... + uzp2 v15.8H, v13.8H, v30.8H // ......................................................................................*........................... + smlal v12.4S, v17.4H, v0.4H // ............................................................................................................*..... 
+ ldr q17, [x5, #-16] // ................................................*................................................................. + ldr q13, [x1, #16] // ......................................................*........................................................... + uzp2 v27.8H, v23.8H, v17.8H // ....................................................*............................................................. + uzp1 v28.8H, v23.8H, v17.8H // ............................................................................*..................................... + uzp2 v7.8H, v12.8H, v22.8H // ...............................................................................................................*.. + ldr q23, [x1], #32 // ..........................................................................*....................................... + zip1 v5.8H, v7.8H, v15.8H // .................................................................................................................* + ldr q3, [x7, #16] // ........................................................................................*......................... + uzp1 v31.8H, v23.8H, v13.8H // ..............................................................................*................................... + uzp2 v16.8H, v23.8H, v13.8H // ...............................................................................*.................................. + smull2 v24.4S, v31.8H, v21.8H // .......................................................................................*.......................... + ldr q6, [x8, #16] // .........................................................................................*........................ + ldr q23, [x10], #32 // ..........................................................................................*....................... 
+ smlal2 v24.4S, v16.8H, v19.8H // ..........................................................................................................*....... + ldr q17, [x10, #-16] // ...........................................................................................*...................... + ld1 {v22.8H}, [x12], #16 // ............................................................................................*..................... + uzp1 v30.8H, v23.8H, v17.8H // .............................................................................................*.................... + uzp2 v11.8H, v23.8H, v17.8H // ..............................................................................................*................... + ldr q23, [x4], #32 // ...............................................................................................*.................. + ldr q17, [x4, #-16] // ................................................................................................*................. + ldr q4, [x7], #32 // .................................................................................................*................ + uzp1 v20.8H, v23.8H, v17.8H // ..................................................................................................*............... + uzp2 v26.8H, v23.8H, v17.8H // ...................................................................................................*.............. + uzp1 v9.8H, v4.8H, v3.8H // .....................................................................................................*............ + smlal2 v24.4S, v20.8H, v27.8H // ..............................................................................................................*... + ld1 {v8.8H}, [x6], #16 // ....................................................................................................*............. 
+ ldr q25, [x11, #16] // ......................................................................................................*........... + ldr q29, [x11], #32 // .......................................................................................................*.......... + ld1 {v12.8H}, [x9], #16 // ........................................................................................................*......... + uzp1 v10.8H, v29.8H, v25.8H // ................................................................................................................*. + ldr q14, [x8], #32 // .........................................................................................................*........ + ld1 {v23.8H}, [x3], #16 // .............................................................................................................*.... + + // ------------------------------------------------- new position --------------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + // ldr q3, [x2], #32 // .*................................................................................................................ + // ldr q17, [x2, #-16] // *................................................................................................................. + // ldr q21, [x5], #32 // ..*............................................................................................................... + // ldr q19, [x5, #-16] // .....*............................................................................................................ + // uzp2 v27.8H, v21.8H, v19.8H // .......*.......................................................................................................... + // ldr q25, [x1, #16] // ......*........................................................................................................... 
+ // ldr q22, [x1], #32 // .........*........................................................................................................ + // uzp1 v28.8H, v21.8H, v19.8H // ........*......................................................................................................... + // uzp1 v31.8H, v22.8H, v25.8H // ...........*...................................................................................................... + // uzp2 v16.8H, v22.8H, v25.8H // ............*..................................................................................................... + // uzp2 v21.8H, v3.8H, v17.8H // ...*.............................................................................................................. + // uzp1 v19.8H, v3.8H, v17.8H // ....*............................................................................................................. + // smull2 v24.4S, v31.8H, v21.8H // .............*.................................................................................................... + // ldr q3, [x7, #16] // ..........*....................................................................................................... + // ldr q6, [x8, #16] // .................................*................................................................................ + // ldr q8, [x10], #32 // ....................................*............................................................................. + // ldr q26, [x10, #-16] // .....................................*............................................................................ + // ld1 {v22.8H}, [x12], #16 // ......................................*........................................................................... + // uzp1 v30.8H, v8.8H, v26.8H // .......................................*.......................................................................... 
+ // uzp2 v11.8H, v8.8H, v26.8H // ........................................*......................................................................... + // ldr q8, [x4], #32 // ...................*.............................................................................................. + // ldr q26, [x4, #-16] // ....................*............................................................................................. + // ldr q4, [x7], #32 // ................................*................................................................................. + // uzp1 v20.8H, v8.8H, v26.8H // ......................*........................................................................................... + // uzp2 v26.8H, v8.8H, v26.8H // .......................*.......................................................................................... + // ld1 {v8.8H}, [x6], #16 // .........................................*........................................................................ + // uzp1 v9.8H, v4.8H, v3.8H // ..................................*............................................................................... + // ldr q25, [x11, #16] // ..........................................*....................................................................... + // ldr q29, [x11], #32 // .............................................*.................................................................... + // ld1 {v12.8H}, [x9], #16 // ..............................................*................................................................... + // ldr q14, [x8], #32 // .................................................*................................................................ + // smlal2 v24.4S, v16.8H, v19.8H // .................*................................................................................................ 
+ // ld1 {v23.8H}, [x3], #16 // .....................*............................................................................................ + // smlal2 v24.4S, v20.8H, v27.8H // ........................*......................................................................................... + // uzp1 v10.8H, v29.8H, v25.8H // ...............................................*.................................................................. + // smlal2 v24.4S, v26.8H, v28.8H // ............................*..................................................................................... + // uzp2 v4.8H, v4.8H, v3.8H // ...................................*.............................................................................. + // smull2 v13.4S, v31.8H, v19.8H // ...............*.................................................................................................. + // ldr q3, [x2], #32 // ..................................................*............................................................... + // uzp2 v1.8H, v29.8H, v25.8H // ................................................*................................................................. + // smlal2 v13.4S, v16.8H, v23.8H // ..........................*....................................................................................... + // ldr q17, [x2, #-16] // .....................................................................*............................................ + // smull v18.4S, v31.4H, v19.4H // ................*................................................................................................. + // smlal2 v13.4S, v20.8H, v28.8H // ..............................*................................................................................... + // smull v29.4S, v31.4H, v21.4H // ..............*................................................................................................... 
+ // ldr q21, [x5], #32 // ..............................................................................*................................... + // smlal2 v13.4S, v26.8H, v8.8H // ...........................................*...................................................................... + // smlal v29.4S, v16.4H, v19.4H // ..................*............................................................................................... + // ldr q19, [x5, #-16] // ..................................................................................*............................... + // smlal v18.4S, v16.4H, v23.4H // ...........................*...................................................................................... + // smlal v29.4S, v20.4H, v27.4H // .........................*........................................................................................ + // uzp1 v31.8H, v14.8H, v6.8H // ...................................................*.............................................................. + // uzp2 v27.8H, v21.8H, v19.8H // ....................................................................................*............................. + // smlal v18.4S, v20.4H, v28.4H // ...............................*.................................................................................. + // ldr q25, [x1, #16] // ...................................................................................*.............................. + // smlal v29.4S, v26.4H, v28.4H // .............................*.................................................................................... + // smlal v18.4S, v26.4H, v8.4H // ............................................*..................................................................... + // uzp2 v26.8H, v14.8H, v6.8H // ....................................................*............................................................. 
+ // smlal2 v13.4S, v9.8H, v31.8H // .....................................................*............................................................ + // smlal2 v24.4S, v9.8H, v26.8H // ......................................................*........................................................... + // smlal v29.4S, v9.4H, v26.4H // .......................................................*.......................................................... + // smlal v18.4S, v9.4H, v31.4H // ........................................................*......................................................... + // smlal2 v13.4S, v4.8H, v12.8H // .........................................................*........................................................ + // smlal2 v24.4S, v4.8H, v31.8H // ..........................................................*....................................................... + // smlal v29.4S, v4.4H, v31.4H // ...........................................................*...................................................... + // smlal v18.4S, v4.4H, v12.4H // ............................................................*..................................................... + // smlal2 v13.4S, v30.8H, v10.8H // .............................................................*.................................................... + // smlal2 v24.4S, v30.8H, v1.8H // ..............................................................*................................................... + // smlal v29.4S, v30.4H, v1.4H // ...............................................................*.................................................. + // smlal v18.4S, v30.4H, v10.4H // ................................................................*................................................. + // smlal2 v13.4S, v11.8H, v22.8H // .................................................................*................................................ 
+ // smlal2 v24.4S, v11.8H, v10.8H // ..................................................................*............................................... + // smlal v29.4S, v11.4H, v10.4H // ...................................................................*.............................................. + // smlal v18.4S, v11.4H, v22.4H // ....................................................................*............................................. + // ldr q22, [x1], #32 // .......................................................................................*.......................... + // uzp1 v31.8H, v29.8H, v24.8H // ......................................................................*........................................... + // uzp1 v28.8H, v21.8H, v19.8H // .....................................................................................*............................ + // mul v19.8H, v31.8H, v2.8H // ........................................................................*......................................... + // uzp1 v31.8H, v22.8H, v25.8H // ..........................................................................................*....................... + // uzp2 v16.8H, v22.8H, v25.8H // ...........................................................................................*...................... + // uzp2 v21.8H, v3.8H, v17.8H // .........................................................................*........................................ + // smlal v29.4S, v19.4H, v0.4H // ............................................................................*..................................... + // smlal2 v24.4S, v19.8H, v0.8H // .............................................................................*.................................... + // uzp1 v19.8H, v3.8H, v17.8H // ..........................................................................*....................................... 
+ // uzp1 v26.8H, v18.8H, v13.8H // .......................................................................*.......................................... + // mul v23.8H, v26.8H, v2.8H // ...........................................................................*...................................... + // uzp2 v15.8H, v29.8H, v24.8H // ................................................................................*................................. + // smull2 v24.4S, v31.8H, v21.8H // ............................................................................................*..................... + // ldr q3, [x7, #16] // .........................................................................................*........................ + // ldr q6, [x8, #16] // .............................................................................................*.................... + // ldr q8, [x10], #32 // ..............................................................................................*................... + // ldr q26, [x10, #-16] // ................................................................................................*................. + // ld1 {v22.8H}, [x12], #16 // .................................................................................................*................ + // uzp1 v30.8H, v8.8H, v26.8H // ..................................................................................................*............... + // uzp2 v11.8H, v8.8H, v26.8H // ...................................................................................................*.............. + // ldr q8, [x4], #32 // ....................................................................................................*............. + // ldr q26, [x4, #-16] // .....................................................................................................*............ 
+ // ldr q4, [x7], #32 // ......................................................................................................*........... + // uzp1 v20.8H, v8.8H, v26.8H // .......................................................................................................*.......... + // uzp2 v26.8H, v8.8H, v26.8H // ........................................................................................................*......... + // ld1 {v8.8H}, [x6], #16 // ...........................................................................................................*...... + // uzp1 v9.8H, v4.8H, v3.8H // .........................................................................................................*........ + // ldr q25, [x11, #16] // ............................................................................................................*..... + // ldr q29, [x11], #32 // .............................................................................................................*.... + // ld1 {v12.8H}, [x9], #16 // ..............................................................................................................*... + // ldr q14, [x8], #32 // ................................................................................................................*. + // smlal2 v24.4S, v16.8H, v19.8H // ...............................................................................................*.................. + // smlal2 v13.4S, v23.8H, v0.8H // ...............................................................................*.................................. + // smlal v18.4S, v23.4H, v0.4H // .................................................................................*................................ 
+ // ld1 {v23.8H}, [x3], #16 // .................................................................................................................* + // smlal2 v24.4S, v20.8H, v27.8H // ..........................................................................................................*....... + // uzp2 v7.8H, v18.8H, v13.8H // ......................................................................................*........................... + // uzp1 v10.8H, v29.8H, v25.8H // ...............................................................................................................*.. + // zip1 v5.8H, v7.8H, v15.8H // ........................................................................................*......................... + + sub count, count, #2 +1: + // Instructions: 82 + // Expected cycles: 102 + // Expected IPC: 0.80 + // + // Cycle bound: 102.0 + // IPC bound: 0.80 + // + // Wall time: 15.93s + // User time: 15.93s + // + // ------------------------------- original position -------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------ + smlal2 v24.4S, v26.8H, v28.8H // .................................*................................................ + uzp2 v4.8H, v4.8H, v3.8H // .....................................*............................................ + smull2 v13.4S, v31.8H, v19.8H // ..........*....................................................................... + ldr q3, [x2], #32 // ....e............................................................................. + uzp2 v1.8H, v29.8H, v25.8H // ..........................................................*....................... + smlal2 v13.4S, v16.8H, v23.8H // ............*..................................................................... + ldr q17, [x2, #-16] // .....e............................................................................ 
+ smull v18.4S, v31.4H, v19.4H // .........*........................................................................ + smlal2 v13.4S, v20.8H, v28.8H // ...........................*...................................................... + smull v29.4S, v31.4H, v21.4H // .............*.................................................................... + ldr q21, [x5], #32 // .....................e............................................................ + smlal2 v13.4S, v26.8H, v8.8H // .............................*.................................................... + smlal v29.4S, v16.4H, v19.4H // ...............*.................................................................. + ldr q19, [x5, #-16] // ......................e........................................................... + smlal v18.4S, v16.4H, v23.4H // ...........*...................................................................... + smlal v29.4S, v20.4H, v27.4H // ..............................*................................................... + uzp1 v31.8H, v14.8H, v6.8H // ........................................*......................................... + uzp2 v27.8H, v21.8H, v19.8H // ........................e......................................................... + smlal v18.4S, v20.4H, v28.4H // ..........................*....................................................... + ldr q25, [x1, #16] // .e................................................................................ + smlal v29.4S, v26.4H, v28.4H // ................................*................................................. + smlal v18.4S, v26.4H, v8.4H // ............................*..................................................... + uzp2 v26.8H, v14.8H, v6.8H // .........................................*........................................ + smlal2 v13.4S, v9.8H, v31.8H // ............................................*..................................... 
+ smlal2 v24.4S, v9.8H, v26.8H // ................................................*................................. + smlal v29.4S, v9.4H, v26.4H // ...............................................*.................................. + smlal v18.4S, v9.4H, v31.4H // ...........................................*...................................... + smlal2 v13.4S, v4.8H, v12.8H // ..............................................*................................... + smlal2 v24.4S, v4.8H, v31.8H // ..................................................*............................... + smlal v29.4S, v4.4H, v31.4H // .................................................*................................ + smlal v18.4S, v4.4H, v12.4H // .............................................*.................................... + smlal2 v13.4S, v30.8H, v10.8H // .............................................................*.................... + smlal2 v24.4S, v30.8H, v1.8H // .................................................................*................ + smlal v29.4S, v30.4H, v1.4H // ................................................................*................. + smlal v18.4S, v30.4H, v10.4H // ............................................................*..................... + smlal2 v13.4S, v11.8H, v22.8H // ...............................................................*.................. + smlal2 v24.4S, v11.8H, v10.8H // ...................................................................*.............. + smlal v29.4S, v11.4H, v10.4H // ..................................................................*............... + smlal v18.4S, v11.4H, v22.4H // ..............................................................*................... + ldr q22, [x1], #32 // e................................................................................. + uzp1 v31.8H, v29.8H, v24.8H // .........................................................................*........ 
+ uzp1 v28.8H, v21.8H, v19.8H // .......................e.......................................................... + mul v19.8H, v31.8H, v2.8H // ..........................................................................*....... + uzp1 v31.8H, v22.8H, v25.8H // ..e............................................................................... + uzp2 v16.8H, v22.8H, v25.8H // ...e.............................................................................. + uzp2 v21.8H, v3.8H, v17.8H // .......e.......................................................................... + smlal v29.4S, v19.4H, v0.4H // ...........................................................................*...... + smlal2 v24.4S, v19.8H, v0.8H // ............................................................................*..... + uzp1 v19.8H, v3.8H, v17.8H // ......e........................................................................... + uzp1 v26.8H, v18.8H, v13.8H // ....................................................................*............. + zip2 v14.8H, v7.8H, v15.8H // ...............................................................................l.. + mul v23.8H, v26.8H, v2.8H // .....................................................................*............ + uzp2 v15.8H, v29.8H, v24.8H // .............................................................................*.... + smull2 v24.4S, v31.8H, v21.8H // ..............e................................................................... + str q14, [x0, #16] // .................................................................................l + ldr q3, [x7, #16] // ...................................e.............................................. + ldr q6, [x8, #16] // .......................................e.......................................... + ldr q8, [x10], #32 // ...................................................e.............................. 
+ ldr q26, [x10, #-16] // ....................................................e............................. + ld1 {v22.8H}, [x12], #16 // ...........................................................e...................... + uzp1 v30.8H, v8.8H, v26.8H // .....................................................e............................ + uzp2 v11.8H, v8.8H, v26.8H // ......................................................e........................... + ldr q8, [x4], #32 // .................e................................................................ + ldr q26, [x4, #-16] // ..................e............................................................... + ldr q4, [x7], #32 // ..................................e............................................... + uzp1 v20.8H, v8.8H, v26.8H // ...................e.............................................................. + uzp2 v26.8H, v8.8H, v26.8H // ....................e............................................................. + ld1 {v8.8H}, [x6], #16 // .........................e........................................................ + uzp1 v9.8H, v4.8H, v3.8H // ....................................e............................................. + ldr q25, [x11, #16] // ........................................................e......................... + ldr q29, [x11], #32 // .......................................................e.......................... + ld1 {v12.8H}, [x9], #16 // ..........................................e....................................... + ldr q14, [x8], #32 // ......................................e........................................... + smlal2 v24.4S, v16.8H, v19.8H // ................e................................................................. + smlal2 v13.4S, v23.8H, v0.8H // .......................................................................*.......... 
+ smlal v18.4S, v23.4H, v0.4H // ......................................................................*........... + ld1 {v23.8H}, [x3], #16 // ........e......................................................................... + smlal2 v24.4S, v20.8H, v27.8H // ...............................e.................................................. + uzp2 v7.8H, v18.8H, v13.8H // ........................................................................*......... + uzp1 v10.8H, v29.8H, v25.8H // .........................................................e........................ + str q5, [x0], #32 // ................................................................................l. + zip1 v5.8H, v7.8H, v15.8H // ..............................................................................*... + + // ----------------------------------------------------------------------------------------------------------------- new position ------------------------------------------------------------------------------------------------------------------> + // 0 25 50 75 100 125 150 175 200 225 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|---------------- + // ldr q12, [x1], #32 // ....................................e..........................................'......................................~..........................................'......................................~......................................... + // ldr q13, [x1, #-16] // ................e..............................................................'..................~..............................................................'..................~............................................................. 
+ // uzp1 v3.8h, v12.8h, v13.8h // ........................................e......................................'..........................................~......................................'..........................................~..................................... + // uzp2 v4.8h, v12.8h, v13.8h // .........................................e.....................................'...........................................~.....................................'...........................................~.................................... + // ldr q12, [x2], #32 // e..............................................................................'..~..............................................................................'..~............................................................................. + // ldr q13, [x2, #-16] // ...e...........................................................................'.....~...........................................................................'.....~.......................................................................... + // uzp1 v5.8h, v12.8h, v13.8h // .............................................e.................................'...............................................~.................................'...............................................~................................ + // uzp2 v6.8h, v12.8h, v13.8h // ..........................................e....................................'............................................~....................................'............................................~................................... + // ld1 {v7.8h}, [x3], #16 // .........................................................................e.....'...........................................................................~.....'...........................................................................~.... 
+ // smull v8.4s, v3.4h, v5.4h // ....~..........................................................................'......*..........................................................................'......~......................................................................... + // smull2 v10.4s, v3.8h, v5.8h // ...............................................................................'.*...............................................................................'.~.............................................................................. + // smlal v8.4s, v4.4h, v7.4h // ...........~...................................................................'.............*...................................................................'.............~.................................................................. + // smlal2 v10.4s, v4.8h, v7.8h // ..~............................................................................'....*............................................................................'....~........................................................................... + // smull v9.4s, v3.4h, v6.4h // ......~........................................................................'........*........................................................................'........~....................................................................... + // smull2 v11.4s, v3.8h, v6.8h // ..................................................e............................'....................................................~............................'....................................................~........................... + // smlal v9.4s, v4.4h, v5.4h // .........~.....................................................................'...........*.....................................................................'...........~.................................................................... 
+ // smlal2 v11.4s, v4.8h, v5.8h // ......................................................................e........'........................................................................~........'........................................................................~....... + // ldr q12, [x4], #32 // ...........................................................e...................'.............................................................~...................'.............................................................~.................. + // ldr q13, [x4, #-16] // ............................................................e..................'..............................................................~..................'..............................................................~................. + // uzp1 v3.8h, v12.8h, v13.8h // ..............................................................e................'................................................................~................'................................................................~............... + // uzp2 v4.8h, v12.8h, v13.8h // ...............................................................e...............'.................................................................~...............'.................................................................~.............. + // ldr q12, [x5], #32 // .......e.......................................................................'.........~.......................................................................'.........~...................................................................... + // ldr q13, [x5, #-16] // ..........e....................................................................'............~....................................................................'............~................................................................... 
+ // uzp1 v5.8h, v12.8h, v13.8h // ......................................e........................................'........................................~........................................'........................................~....................................... + // uzp2 v6.8h, v12.8h, v13.8h // ..............e................................................................'................~................................................................'................~............................................................... + // ld1 {v7.8h}, [x6], #16 // ................................................................e..............'..................................................................~..............'..................................................................~............. + // smlal v8.4s, v3.4h, v5.4h // ...............~...............................................................'.................*...............................................................'.................~.............................................................. + // smlal2 v10.4s, v3.8h, v5.8h // .....~.........................................................................'.......*.........................................................................'.......~........................................................................ + // smlal v8.4s, v4.4h, v7.4h // ..................~............................................................'....................*............................................................'....................~........................................................... + // smlal2 v10.4s, v4.8h, v7.8h // ........~......................................................................'..........*......................................................................'..........~..................................................................... 
+ // smlal v9.4s, v3.4h, v6.4h // ............~..................................................................'..............*..................................................................'..............~................................................................. + // smlal2 v11.4s, v3.8h, v6.8h // ..........................................................................e....'............................................................................~....'............................................................................~... + // smlal v9.4s, v4.4h, v5.4h // .................~.............................................................'...................*.............................................................'...................~............................................................ + // smlal2 v11.4s, v4.8h, v5.8h // ...............................................................................*.................................................................................~................................................................................ + // ldr q12, [x7], #32 // .............................................................e.................'...............................................................~.................'...............................................................~................ + // ldr q13, [x7, #-16] // ....................................................e..........................'......................................................~..........................'......................................................~......................... + // uzp1 v3.8h, v12.8h, v13.8h // .................................................................e.............'...................................................................~.............'...................................................................~............ 
+ // uzp2 v4.8h, v12.8h, v13.8h // ...............................................................................'*................................................................................'~............................................................................... + // ldr q12, [x8], #32 // .....................................................................e.........'.......................................................................~.........'.......................................................................~........ + // ldr q13, [x8, #-16] // .....................................................e.........................'.......................................................~.........................'.......................................................~........................ + // uzp1 v5.8h, v12.8h, v13.8h // .............~.................................................................'...............*.................................................................'...............~................................................................ + // uzp2 v6.8h, v12.8h, v13.8h // ...................~...........................................................'.....................*...........................................................'.....................~.......................................................... + // ld1 {v7.8h}, [x9], #16 // ....................................................................e..........'......................................................................~..........'......................................................................~......... + // smlal v8.4s, v3.4h, v5.4h // .......................~.......................................................'.........................*.......................................................'.........................~...................................................... 
+ // smlal2 v10.4s, v3.8h, v5.8h // ....................~..........................................................'......................*..........................................................'......................~......................................................... + // smlal v8.4s, v4.4h, v7.4h // ...........................~...................................................'.............................*...................................................'.............................~.................................................. + // smlal2 v10.4s, v4.8h, v7.8h // ........................~......................................................'..........................*......................................................'..........................~..................................................... + // smlal v9.4s, v3.4h, v6.4h // ......................~........................................................'........................*........................................................'........................~....................................................... + // smlal2 v11.4s, v3.8h, v6.8h // .....................~.........................................................'.......................*.........................................................'.......................~........................................................ + // smlal v9.4s, v4.4h, v5.4h // ..........................~....................................................'............................*....................................................'............................~................................................... + // smlal2 v11.4s, v4.8h, v5.8h // .........................~.....................................................'...........................*.....................................................'...........................~.................................................... 
+ // ldr q12, [x10], #32 // ......................................................e........................'........................................................~........................'........................................................~....................... + // ldr q13, [x10, #-16] // .......................................................e.......................'.........................................................~.......................'.........................................................~...................... + // uzp1 v3.8h, v12.8h, v13.8h // .........................................................e.....................'...........................................................~.....................'...........................................................~.................... + // uzp2 v4.8h, v12.8h, v13.8h // ..........................................................e....................'............................................................~....................'............................................................~................... + // ldr q12, [x11], #32 // ...................................................................e...........'.....................................................................~...........'.....................................................................~.......... + // ldr q13, [x11, #-16] // ..................................................................e............'....................................................................~............'....................................................................~........... + // uzp1 v5.8h, v12.8h, v13.8h // ............................................................................e..'..............................................................................~..'..............................................................................~. 
+ // uzp2 v6.8h, v12.8h, v13.8h // .~.............................................................................'...*.............................................................................'...~............................................................................ + // ld1 {v7.8h}, [x12], #16 // ........................................................e......................'..........................................................~......................'..........................................................~..................... + // smlal v8.4s, v3.4h, v5.4h // ...............................~...............................................'.................................*...............................................'.................................~.............................................. + // smlal2 v10.4s, v3.8h, v5.8h // ............................~..................................................'..............................*..................................................'..............................~................................................. + // smlal v8.4s, v4.4h, v7.4h // ...................................~...........................................'.....................................*...........................................'.....................................~.......................................... + // smlal2 v10.4s, v4.8h, v7.8h // ................................~..............................................'..................................*..............................................'..................................~............................................. + // smlal v9.4s, v3.4h, v6.4h // ..............................~................................................'................................*................................................'................................~............................................... 
+ // smlal2 v11.4s, v3.8h, v6.8h // .............................~.................................................'...............................*.................................................'...............................~................................................ + // smlal v9.4s, v4.4h, v5.4h // ..................................~............................................'....................................*............................................'....................................~........................................... + // smlal2 v11.4s, v4.8h, v5.8h // .................................~.............................................'...................................*.............................................'...................................~............................................ + // uzp1 v28.8h, v8.8h, v10.8h // ..............................................~................................'................................................*................................'................................................~............................... + // mul v28.8h, v28.8h, v2.8h // ................................................~..............................'..................................................*..............................'..................................................~............................. + // smlal v8.4s, v28.4h, v0.4h // ........................................................................~......'..........................................................................*......'..........................................................................~..... + // smlal2 v10.4s, v28.8h, v0.8h // .......................................................................~.......'.........................................................................*.......'.........................................................................~...... 
+ // uzp2 v26.8h, v8.8h, v10.8h // ...........................................................................~...'.............................................................................*...'.............................................................................~.. + // uzp1 v28.8h, v9.8h, v11.8h // .....................................~.........................................'.......................................*.........................................'.......................................~........................................ + // mul v28.8h, v28.8h, v2.8h // .......................................~.......................................'.........................................*.......................................'.........................................~...................................... + // smlal v9.4s, v28.4h, v0.4h // ...........................................~...................................'.............................................*...................................'.............................................~.................................. + // smlal2 v11.4s, v28.8h, v0.8h // ............................................~..................................'..............................................*..................................'..............................................~................................. + // uzp2 v27.8h, v9.8h, v11.8h // .................................................~.............................'...................................................*.............................'...................................................~............................ + // zip1 v12.8h, v26.8h, v27.8h // ..............................................................................~'................................................................................*'................................................................................ 
+ // zip2 v13.8h, v26.8h, v27.8h // ...............................................~...............................'.................................................~...............................'.................................................l.............................. + // str q12, [x0], #32 // .............................................................................~.'...............................................................................~.'...............................................................................l + // str q13, [x0, #-16] // ...................................................~...........................'.....................................................~...........................'.....................................................l.......................... + + sub count, count, #1 + cbnz count, 1b + // Instructions: 50 + // Expected cycles: 56 + // Expected IPC: 0.89 + // + // Cycle bound: 56.0 + // IPC bound: 0.89 + // + // Wall time: 4.16s + // User time: 4.16s + // + // --------------- original position ---------------> + // 0 25 + // |------------------------| + smull2 v17.4S, v31.8H, v19.8H // ..*............................................... + uzp2 v1.8H, v14.8H, v6.8H // ................*................................. + smull v18.4S, v31.4H, v21.4H // .......*.......................................... + smlal2 v24.4S, v26.8H, v28.8H // *................................................. + smlal2 v17.4S, v16.8H, v23.8H // ....*............................................. + smull v21.4S, v31.4H, v19.4H // .....*............................................ + smlal v18.4S, v16.4H, v19.4H // .........*........................................ + uzp2 v31.8H, v4.8H, v3.8H // .*................................................ + uzp1 v3.8H, v14.8H, v6.8H // ............*..................................... + smlal v21.4S, v16.4H, v23.4H // ..........*....................................... 
+ smlal v18.4S, v20.4H, v27.4H // ...........*...................................... + uzp2 v14.8H, v29.8H, v25.8H // ...*.............................................. + smlal2 v17.4S, v20.8H, v28.8H // ......*........................................... + smlal v21.4S, v20.4H, v28.4H // .............*.................................... + smlal v18.4S, v26.4H, v28.4H // ..............*................................... + smlal2 v24.4S, v9.8H, v1.8H // ..................*............................... + smlal2 v17.4S, v26.8H, v8.8H // ........*......................................... + smlal v21.4S, v26.4H, v8.4H // ...............*.................................. + smlal v18.4S, v9.4H, v1.4H // ...................*.............................. + smlal2 v24.4S, v31.8H, v3.8H // ......................*........................... + smlal2 v17.4S, v9.8H, v3.8H // .................*................................ + smlal v21.4S, v9.4H, v3.4H // ....................*............................. + smlal v18.4S, v31.4H, v3.4H // .......................*.......................... + smlal2 v24.4S, v30.8H, v14.8H // ..........................*....................... + smlal2 v17.4S, v31.8H, v12.8H // .....................*............................ + smlal v21.4S, v31.4H, v12.4H // ........................*......................... + smlal v18.4S, v30.4H, v14.4H // ...........................*...................... + smlal2 v24.4S, v11.8H, v10.8H // ..............................*................... + smlal2 v17.4S, v30.8H, v10.8H // .........................*........................ + smlal v21.4S, v30.4H, v10.4H // ............................*..................... + smlal v18.4S, v11.4H, v10.4H // ...............................*.................. + zip2 v19.8H, v7.8H, v15.8H // ......................................*........... + smlal2 v17.4S, v11.8H, v22.8H // .............................*.................... 
+ smlal v21.4S, v11.4H, v22.4H // ................................*................. + uzp1 v23.8H, v18.8H, v24.8H // .................................*................ + str q19, [x0, #16] // .........................................*........ + mul v19.8H, v23.8H, v2.8H // ..................................*............... + uzp1 v23.8H, v21.8H, v17.8H // .....................................*............ + str q5, [x0], #32 // .............................................*.... + mul v26.8H, v23.8H, v2.8H // .......................................*.......... + smlal v18.4S, v19.4H, v0.4H // ...................................*.............. + smlal2 v24.4S, v19.8H, v0.8H // ....................................*............. + smlal v21.4S, v26.4H, v0.4H // ...........................................*...... + smlal2 v17.4S, v26.8H, v0.8H // ..........................................*....... + uzp2 v13.8H, v18.8H, v24.8H // ........................................*......... + uzp2 v19.8H, v21.8H, v17.8H // ............................................*..... + zip1 v23.8H, v19.8H, v13.8H // ..............................................*... + zip2 v19.8H, v19.8H, v13.8H // ...............................................*.. + str q23, [x0], #32 // .................................................* + str q19, [x0, #-16] // ................................................*. + + // ----------------- new position ------------------> + // 0 25 + // |------------------------|------------------------ + // smlal2 v24.4S, v26.8H, v28.8H // ...*.............................................. + // uzp2 v4.8H, v4.8H, v3.8H // .......*.......................................... + // smull2 v13.4S, v31.8H, v19.8H // *................................................. + // uzp2 v1.8H, v29.8H, v25.8H // ...........*...................................... + // smlal2 v13.4S, v16.8H, v23.8H // ....*............................................. 
+ // smull v18.4S, v31.4H, v19.4H // .....*............................................ + // smlal2 v13.4S, v20.8H, v28.8H // ............*..................................... + // smull v29.4S, v31.4H, v21.4H // ..*............................................... + // smlal2 v13.4S, v26.8H, v8.8H // ................*................................. + // smlal v29.4S, v16.4H, v19.4H // ......*........................................... + // smlal v18.4S, v16.4H, v23.4H // .........*........................................ + // smlal v29.4S, v20.4H, v27.4H // ..........*....................................... + // uzp1 v31.8H, v14.8H, v6.8H // ........*......................................... + // smlal v18.4S, v20.4H, v28.4H // .............*.................................... + // smlal v29.4S, v26.4H, v28.4H // ..............*................................... + // smlal v18.4S, v26.4H, v8.4H // .................*................................ + // uzp2 v26.8H, v14.8H, v6.8H // .*................................................ + // smlal2 v13.4S, v9.8H, v31.8H // ....................*............................. + // smlal2 v24.4S, v9.8H, v26.8H // ...............*.................................. + // smlal v29.4S, v9.4H, v26.4H // ..................*............................... + // smlal v18.4S, v9.4H, v31.4H // .....................*............................ + // smlal2 v13.4S, v4.8H, v12.8H // ........................*......................... + // smlal2 v24.4S, v4.8H, v31.8H // ...................*.............................. + // smlal v29.4S, v4.4H, v31.4H // ......................*........................... + // smlal v18.4S, v4.4H, v12.4H // .........................*........................ + // smlal2 v13.4S, v30.8H, v10.8H // ............................*..................... + // smlal2 v24.4S, v30.8H, v1.8H // .......................*.......................... 
+ // smlal v29.4S, v30.4H, v1.4H // ..........................*....................... + // smlal v18.4S, v30.4H, v10.4H // .............................*.................... + // smlal2 v13.4S, v11.8H, v22.8H // ................................*................. + // smlal2 v24.4S, v11.8H, v10.8H // ...........................*...................... + // smlal v29.4S, v11.4H, v10.4H // ..............................*................... + // smlal v18.4S, v11.4H, v22.4H // .................................*................ + // uzp1 v31.8H, v29.8H, v24.8H // ..................................*............... + // mul v19.8H, v31.8H, v2.8H // ....................................*............. + // smlal v29.4S, v19.4H, v0.4H // ........................................*......... + // smlal2 v24.4S, v19.8H, v0.8H // .........................................*........ + // uzp1 v26.8H, v18.8H, v13.8H // .....................................*............ + // zip2 v14.8H, v7.8H, v15.8H // ...............................*.................. + // mul v23.8H, v26.8H, v2.8H // .......................................*.......... + // uzp2 v15.8H, v29.8H, v24.8H // ............................................*..... + // str q14, [x0, #16] // ...................................*.............. + // smlal2 v13.4S, v23.8H, v0.8H // ...........................................*...... + // smlal v18.4S, v23.4H, v0.4H // ..........................................*....... + // uzp2 v7.8H, v18.8H, v13.8H // .............................................*.... + // str q5, [x0], #32 // ......................................*........... + // zip1 v5.8H, v7.8H, v15.8H // ..............................................*... + // zip2 v14.8H, v7.8H, v15.8H // ...............................................*.. + // str q14, [x0, #16] // .................................................* + // str q5, [x0], #32 // ................................................*. 
+ + + pop_stack + ret +#endif /* MLKEM_K == 4 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_asm_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_asm_clean.S new file mode 100644 index 0000000000..722dc0f49e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_asm_clean.S @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/************************************************* + * Name: rej_uniform_asm_clean + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer of MLKEM_N + * 16-bit coefficients. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * Must be a multiple of 24. + * + * Returns number of sampled 16-bit integers (at most MLKEM_N). + **************************************************/ +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +// We save the output on the stack first, and copy to the actual +// output buffer only in the end. This is because the main loop can overwrite +// by up to 62 bytes, which we account for here (we use 64 bytes for alignment). 
+#define STACK_SIZE (2*MLKEM_N + 64) +#define STACK_OFFSET_TMP_OUTPUT 0 + +.macro push_stack + sub sp, sp, #STACK_SIZE +.endm + +.macro pop_stack + add sp, sp, #STACK_SIZE +.endm + + /* Parameters: x0-x2 match the arguments documented in the header; x3 (table_idx) is the base address of the lane-shuffle lookup table used by the table loads below (presumably rej_uniform_table; confirm at the call site) */ + output .req x0 + buf .req x1 + buflen .req w2 + table_idx .req x3 + + len .req w4 + + /* Temporary output on the stack */ + output_tmp .req x7 + output_tmp_base .req x8 + + /* Number of coefficients sampled so far */ + count .req w9 + buf_consumed .req w10 // NOTE(review): not referenced anywhere in this routine + + /* Temporary registers */ + tmp .req w11 + final_copy_count .req w11 + + rec_idx_0 .req w12 + rec_idx_1 .req w13 + rec_idx_2 .req w14 + rec_idx_3 .req w15 + + ctr0 .req w12 // ctr0-ctr3 share w12-w15 with rec_idx_0-rec_idx_3 + ctr1 .req w13 + ctr2 .req w14 + ctr3 .req w15 + + ctr01 .req ctr0 + ctr23 .req ctr2 + + /* Vector registers */ + + buf0 .req v0 + buf1 .req v1 + buf2 .req v2 + + tmp0 .req v4 + tmp1 .req v5 + tmp2 .req v6 + tmp3 .req v7 + + sign0 .req v4 + sign1 .req v5 + sign2 .req v6 + sign3 .req v7 + + val0 .req v16 + val0q .req q16 + val1 .req v17 + val1q .req q17 + val2 .req v18 + val2q .req q18 + val3 .req v19 + val3q .req q19 + + t0 .req s20 + t1 .req s21 + t2 .req s22 + t3 .req s23 + + table0 .req v24 + table0q .req q24 + table1 .req v25 + table1q .req q25 + table2 .req v26 + table2q .req q26 + table3 .req v27 + table3q .req q27 + + mlkem_q .req v30 + bits .req v31 + bits_q .req q31 + +.text +/* Literal pool */ +.p2align 4 +c_bit_table: + .short 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80 + +.align 4 +.global MLKEM_ASM_NAMESPACE(rej_uniform_asm_clean) +MLKEM_ASM_NAMESPACE(rej_uniform_asm_clean): + push_stack + + ldr bits_q, c_bit_table + movz tmp, #MLKEM_Q + dup mlkem_q.8h, tmp + + add output_tmp_base, sp, #STACK_OFFSET_TMP_OUTPUT + mov 
output_tmp, output_tmp_base + + mov count, #0 + mov len, #MLKEM_N + + cmp buflen, #48 + b.lo loop48_end + +loop48: + // Finish once we've generated sufficiently many coefficients + cmp count, len + b.hs memory_copy + + // First, we unpack the byte stream into a stream of + // coefficients, interpreting each 
consecutive 3 bytes as two + // unsigned 12-bit coefficients, presented as 16-bit integers. + // + // We handle 16 such triples at a time, and use ld3 for the required + // de-interleaving of the byte stream. + sub buflen, buflen, #48 + ld3 {buf0.16b, buf1.16b, buf2.16b}, [buf], #48 + + // Unpack 16 triples of bytes into 16 pairs of 16-bit integers, + // represented as 4 vectors val0-val3. + zip1 tmp0.16b, buf0.16b, buf1.16b + zip2 tmp1.16b, buf0.16b, buf1.16b + zip1 tmp2.16b, buf1.16b, buf2.16b + zip2 tmp3.16b, buf1.16b, buf2.16b + + bic tmp0.8h, #0xf0, lsl 8 + bic tmp1.8h, #0xf0, lsl 8 + ushr tmp2.8h, tmp2.8h, #4 + ushr tmp3.8h, tmp3.8h, #4 + + zip1 val0.8h, tmp0.8h, tmp2.8h + zip2 val1.8h, tmp0.8h, tmp2.8h + zip1 val2.8h, tmp1.8h, tmp3.8h + zip2 val3.8h, tmp1.8h, tmp3.8h + + // At this point, val0-val3 are the candidate integers to do rejection + // sampling on. For each of them, do the following: + // - Check which coefficients are within range, and represent the set + // of lane-indices of those coefficients as an 8-bit bitmap. + // - Move the respective lanes to the front of the vector. This is the + // most complex part, and is done by interpreting the 8-bit bitmap as + // an index into a lookup table giving the lane-table to be used for + // the `tbl` instruction. + // - Write the vector to the output buffer, but merely increase the output + // buffer pointer by the number of valid coefficients. 
+ + // Set valid lanes to -1 (0b1...1) + cmhi sign0.8h, mlkem_q.8h, val0.8h + cmhi sign1.8h, mlkem_q.8h, val1.8h + cmhi sign2.8h, mlkem_q.8h, val2.8h + cmhi sign3.8h, mlkem_q.8h, val3.8h + + // If lane i is valid and has value -1, retain only i-th bit + and sign0.16b, sign0.16b, bits.16b + and sign1.16b, sign1.16b, bits.16b + and sign2.16b, sign2.16b, bits.16b + and sign3.16b, sign3.16b, bits.16b + + // Get 8-bit bitmap of valid lane indices by adding lanes + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + uaddlv t2, sign2.8h + uaddlv t3, sign3.8h + + fmov rec_idx_0, t0 + fmov rec_idx_1, t1 + fmov rec_idx_2, t2 + fmov rec_idx_3, t3 + + ldr table0q, [table_idx, rec_idx_0, uxtw #4] // table entries are 16 bytes each, hence the index scaling by 16 + ldr table1q, [table_idx, rec_idx_1, uxtw #4] + ldr table2q, [table_idx, rec_idx_2, uxtw #4] + ldr table3q, [table_idx, rec_idx_3, uxtw #4] + + // Compute number of valid coefficients. Recall that at this + // point, lane i has value 2^i (hence popcount 1) if its coefficient + // is valid, and 0 otherwise. + cnt sign0.16b, sign0.16b + cnt sign1.16b, sign1.16b + cnt sign2.16b, sign2.16b + cnt sign3.16b, sign3.16b + + // Extract number of valid coefficients + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + uaddlv t2, sign2.8h + uaddlv t3, sign3.8h + + fmov ctr0, t0 + fmov ctr1, t1 + fmov ctr2, t2 + fmov ctr3, t3 + + // Move valid coefficients to the front + tbl val0.16b, {val0.16b}, table0.16b + tbl val1.16b, {val1.16b}, table1.16b + tbl val2.16b, {val2.16b}, table2.16b + tbl val3.16b, {val3.16b}, table3.16b + + str val0q, [output_tmp] + add output_tmp, output_tmp, ctr0, uxtw #1 + + str val1q, [output_tmp] + add output_tmp, output_tmp, ctr1, uxtw #1 + + str val2q, [output_tmp] + add output_tmp, output_tmp, ctr2, uxtw #1 + + str val3q, [output_tmp] + add output_tmp, output_tmp, ctr3, uxtw #1 + + add ctr01, ctr0, ctr1 + add ctr23, ctr2, ctr3 + add count, count, ctr01 + add count, count, ctr23 + + cmp buflen, #48 + b.hs loop48 +loop48_end: + + // 
Finish once we've generated sufficiently many 
coefficients + cmp count, len + b.hs memory_copy + + cmp buflen, #24 + b.lo memory_copy // fewer than 24 bytes left: no tail block to process + + sub buflen, buflen, #24 // Tail: one 24-byte block (8 triples), same steps as the main loop using val0/val1 only + ld3 {buf0.8b, buf1.8b, buf2.8b}, [buf], #24 + + zip1 tmp0.16b, buf0.16b, buf1.16b + zip1 tmp1.16b, buf1.16b, buf2.16b + + bic tmp0.8h, #0xf0, lsl 8 + ushr tmp1.8h, tmp1.8h, #4 + + zip1 val0.8h, tmp0.8h, tmp1.8h + zip2 val1.8h, tmp0.8h, tmp1.8h + + cmhi sign0.8h, mlkem_q.8h, val0.8h + cmhi sign1.8h, mlkem_q.8h, val1.8h + + and sign0.16b, sign0.16b, bits.16b + and sign1.16b, sign1.16b, bits.16b + + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + + fmov rec_idx_0, t0 + fmov rec_idx_1, t1 + + ldr table0q, [table_idx, rec_idx_0, uxtw #4] + ldr table1q, [table_idx, rec_idx_1, uxtw #4] + + cnt sign0.16b, sign0.16b + cnt sign1.16b, sign1.16b + + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + + fmov ctr0, t0 + fmov ctr1, t1 + + tbl val0.16b, {val0.16b}, table0.16b + tbl val1.16b, {val1.16b}, table1.16b + + str val0q, [output_tmp] + add output_tmp, output_tmp, ctr0, uxtw #1 + + str val1q, [output_tmp] + add output_tmp, output_tmp, ctr1, uxtw #1 + + add count, count, ctr0 + add count, count, ctr1 + +memory_copy: + // count = min(count, len) + cmp count, len + csel count, count, len, lo + + // Always copy MLKEM_N coefficients from the stack to the destination, + // even if not all of them may be valid. This simplifies the loop and + // allows us to stick to vectorized code. 
+ mov final_copy_count, #0 + mov output_tmp, output_tmp_base +final_copy: + ldr val0q, [output_tmp], #64 + ldr val1q, [output_tmp, #-48] + ldr val2q, [output_tmp, #-32] + ldr val3q, [output_tmp, #-16] + str val0q, [output], #64 + str val1q, [output, #-48] + str val2q, [output, #-32] + str val3q, [output, #-16] + add final_copy_count, final_copy_count, #32 // 64 bytes, i.e. 32 16-bit coefficients, are copied per iteration + cmp final_copy_count, #MLKEM_N + b.lt final_copy + + mov w0, count // return value: number of sampled coefficients, already capped at MLKEM_N + b return + +return: + pop_stack + ret + +#endif /* defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_table.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_table.c new file mode 100644 index 0000000000..507660349d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/aarch64/src/rej_uniform_table.c @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +#include +#include "arith_native_aarch64.h" + +/* + * Lookup table used by rejection sampling of the public matrix. + * See autogen for details. 
+ */ +ALIGN const uint8_t rej_uniform_table[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 0 */, + 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 1 */, + 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 2 */, + 0, 1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 3 */, + 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 4 */, + 0, 1, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 5 */, + 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 6 */, + 0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 7 */, + 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 8 */, + 0, 1, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 9 */, + 2, 3, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 10 */, + 0, 1, 2, 3, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 11 */, + 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 12 */, + 0, 1, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 13 */, + 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 14 */, + 0, 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1 /* 15 */, + 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 16 */, + 0, 1, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 17 */, + 2, 3, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 18 */, + 0, 1, 2, 3, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 19 */, + 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 20 */, + 0, 1, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 21 */, + 2, 3, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 22 */, + 0, 1, 2, 3, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 23 */, + 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 24 */, + 0, 1, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 25 */, + 2, 3, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 26 */, + 0, 1, 2, 3, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, 
-1 /* 27 */, + 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 28 */, + 0, 1, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 29 */, + 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 30 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1 /* 31 */, + 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 32 */, + 0, 1, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 33 */, + 2, 3, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 34 */, + 0, 1, 2, 3, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 35 */, + 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 36 */, + 0, 1, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 37 */, + 2, 3, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 38 */, + 0, 1, 2, 3, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 39 */, + 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 40 */, + 0, 1, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 41 */, + 2, 3, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 42 */, + 0, 1, 2, 3, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 43 */, + 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 44 */, + 0, 1, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 45 */, + 2, 3, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 46 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1 /* 47 */, + 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 48 */, + 0, 1, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 49 */, + 2, 3, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 50 */, + 0, 1, 2, 3, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 51 */, + 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 52 */, + 0, 1, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 53 */, + 2, 3, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 54 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 55 */, + 6, 7, 8, 9, 10, 11, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 56 */, + 0, 1, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 57 */, + 2, 3, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 58 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 59 */, + 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 60 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 61 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 62 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1 /* 63 */, + 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 64 */, + 0, 1, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 65 */, + 2, 3, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 66 */, + 0, 1, 2, 3, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 67 */, + 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 68 */, + 0, 1, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 69 */, + 2, 3, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 70 */, + 0, 1, 2, 3, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 71 */, + 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 72 */, + 0, 1, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 73 */, + 2, 3, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 74 */, + 0, 1, 2, 3, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 75 */, + 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 76 */, + 0, 1, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 77 */, + 2, 3, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 78 */, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1 /* 79 */, + 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 80 */, + 0, 1, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 81 */, + 2, 3, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 82 */, + 0, 1, 2, 3, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 83 */, + 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1 /* 84 */, + 0, 1, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 85 */, + 2, 3, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 86 */, + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 87 */, + 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 88 */, + 0, 1, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 89 */, + 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 90 */, + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 91 */, + 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 92 */, + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 93 */, + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 94 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1 /* 95 */, + 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 96 */, + 0, 1, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 97 */, + 2, 3, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 98 */, + 0, 1, 2, 3, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 99 */, + 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 100 */, + 0, 1, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 101 */, + 2, 3, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 102 */, + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 103 */, + 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 104 */, + 0, 1, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 105 */, + 2, 3, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 106 */, + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 107 */, + 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 108 */, + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 109 */, + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 110 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1 /* 111 */, + 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 112 */, + 0, 1, 8, 9, 
10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 113 */, + 2, 3, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 114 */, + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 115 */, + 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 116 */, + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 117 */, + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 118 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 119 */, + 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 120 */, + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 121 */, + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 122 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 123 */, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 124 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 125 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 126 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1 /* 127 */, + 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 128 */, + 0, 1, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 129 */, + 2, 3, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 130 */, + 0, 1, 2, 3, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 131 */, + 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 132 */, + 0, 1, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 133 */, + 2, 3, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 134 */, + 0, 1, 2, 3, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 135 */, + 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 136 */, + 0, 1, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 137 */, + 2, 3, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 138 */, + 0, 1, 2, 3, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 139 */, + 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 140 */, + 0, 1, 4, 5, 6, 7, 14, 
15, -1, -1, -1, -1, -1, -1, -1, -1 /* 141 */, + 2, 3, 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 142 */, + 0, 1, 2, 3, 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1 /* 143 */, + 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 144 */, + 0, 1, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 145 */, + 2, 3, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 146 */, + 0, 1, 2, 3, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 147 */, + 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 148 */, + 0, 1, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 149 */, + 2, 3, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 150 */, + 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 151 */, + 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 152 */, + 0, 1, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 153 */, + 2, 3, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 154 */, + 0, 1, 2, 3, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 155 */, + 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 156 */, + 0, 1, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 157 */, + 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 158 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1 /* 159 */, + 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 160 */, + 0, 1, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 161 */, + 2, 3, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 162 */, + 0, 1, 2, 3, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 163 */, + 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 164 */, + 0, 1, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 165 */, + 2, 3, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 166 */, + 0, 1, 2, 3, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 167 */, + 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 168 */, + 0, 1, 6, 7, 10, 11, 14, 15, -1, -1, 
-1, -1, -1, -1, -1, -1 /* 169 */, + 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 170 */, + 0, 1, 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 171 */, + 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 172 */, + 0, 1, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 173 */, + 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 174 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1 /* 175 */, + 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 176 */, + 0, 1, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 177 */, + 2, 3, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 178 */, + 0, 1, 2, 3, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 179 */, + 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 180 */, + 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 181 */, + 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 182 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 183 */, + 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 184 */, + 0, 1, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 185 */, + 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 186 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 187 */, + 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 188 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 189 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 190 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1 /* 191 */, + 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 192 */, + 0, 1, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 193 */, + 2, 3, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 194 */, + 0, 1, 2, 3, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 195 */, + 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 196 */, + 0, 1, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, 
-1, -1 /* 197 */, + 2, 3, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 198 */, + 0, 1, 2, 3, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 199 */, + 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 200 */, + 0, 1, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 201 */, + 2, 3, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 202 */, + 0, 1, 2, 3, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 203 */, + 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 204 */, + 0, 1, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 205 */, + 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 206 */, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1 /* 207 */, + 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 208 */, + 0, 1, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 209 */, + 2, 3, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 210 */, + 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 211 */, + 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 212 */, + 0, 1, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 213 */, + 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 214 */, + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 215 */, + 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 216 */, + 0, 1, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 217 */, + 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 218 */, + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 219 */, + 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 220 */, + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 221 */, + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 222 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1 /* 223 */, + 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 224 */, + 0, 1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 225 */, + 
2, 3, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 226 */, + 0, 1, 2, 3, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 227 */, + 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 228 */, + 0, 1, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 229 */, + 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 230 */, + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 231 */, + 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 232 */, + 0, 1, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 233 */, + 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 234 */, + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 235 */, + 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 236 */, + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 237 */, + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 238 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1 /* 239 */, + 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 240 */, + 0, 1, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 241 */, + 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 242 */, + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 243 */, + 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 244 */, + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 245 */, + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 246 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 247 */, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 248 */, + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 249 */, + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 250 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 251 */, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 252 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 253 */, + 2, 3, 4, 5, 
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 254 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 /* 255 */, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_aarch64_rej_uniform_table \ + MLKEM_NAMESPACE(empty_cu_aarch64_rej_uniform_table) +int empty_cu_aarch64_rej_uniform_table; +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. 
This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. They are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. 
+ */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions." +#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is assumed to be in normal order. 
+ * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. 
+ * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial operand. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. 
Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: De-serialization of a polynomial. + * Reads a byte array of MLKEM_POLYBYTES bytes and + * writes the result to the output polynomial. + * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Return -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len). 
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. + * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= 
MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. + * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include <stdint.h> +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-value predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { int k; (0 <= k && k <= MLKEM_N-1) ==> ( + * 0 <= a->coeffs[k]) && a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1. 
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols. 
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * than configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers. 
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202_<BACKEND>_ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM<LEVEL>_<BACKEND>_ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLKEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + 
if (err == 1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. 
+ * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64_AVX2 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include <stddef.h> +#include <stdint.h> +#include <string.h> +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. 
+ * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. + **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. 
*/ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + * key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * polyvec *b: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input 
serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. 
+ * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block a time until we're done. + */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. */ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. 
*/ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. + */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. 
*/ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). 
+ **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output buffers are needed. + * The laster parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + 
+#include +#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to ouptput matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. 
+ * Described in Section 7.2 of FIPS203. + * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no public information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * of MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include <stdint.h> + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - root: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, ntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include <stdint.h> +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define 
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include <stdint.h> +#include <string.h> + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < 
(1u << 11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, 
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328). 
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. 
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise non-negative and < UINT12_LIMIT. + * + * The result is coefficient-wise bound by 2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute(). 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be + * subtracted from. + * - const poly *b: Pointer to input polynomial to subtract from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + 
+MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 3/2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constraint native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * 
Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized in [0..4095]. + * - const uint8_t *a: pointer to input byte array 
+ * (of length MLKEM_POLYVECBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. 
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. 
+ * If the new offset is strictly less than len, all of the input buffers + * is guaranteed to have been consumed. If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as 
possible. */ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffer + * is guaranteed to have been consumed. 
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + 
* Referring to (eq 4.3), `OUT` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORDER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ */ +#endif /* defined(__BYTE_ORDER__) */ +#endif /* !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)). 
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned conversion warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour.
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time.
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value.
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_aarch64/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. + */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/LICENSE similarity index 100% rename from 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/LICENSE rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/LICENSE diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. 
In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. The are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. + */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions." 
+#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is assumed to be in normal order. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial.
+ * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial operand.
+ * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: Deserialization of a polynomial + * from a byte array of MLKEM_POLYBYTES + * bytes.
+ * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Return -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len). + **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend.
+ * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. 
+ * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. 
+ * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) 
+/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2.
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-value predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { int k; (0 <= k && k <= MLKEM_N-1) ==> ( + * (0 <= a->coeffs[k]) && (a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1. 
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols. 
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * than configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers. 
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202_<BACKEND>_ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM<LEVEL>_<BACKEND>_ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLKEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include <stdio.h> +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if (err == 
1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. + * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64_AVX2 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include <stddef.h> +#include <stdint.h> +#include <string.h> +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. 
with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. + * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. 
+ **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. */ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + *key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * polyvec *b: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void 
pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block a time until we're done. 
+ */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. 
*/ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. */ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. 
+ */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. */ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. 
Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). + **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output 
buffers are needed. + * The last parameter is a dummy that's overwritten later. + */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + +#include 
+#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. + * Described in Section 7.2 of FIPS203. 
+ * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no public information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * of MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, ntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define 
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < (1u << 
11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + /* r is sampled via poly_cbd_eta2, so check against ETA2 (not ETA1) */ + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, 
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328). 
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. 
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to have coefficients in [0, UINT12_LIMIT). + * + * The result is coefficient-wise bound by 2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute(). 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials in place; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be subtracted + *from. + * - const poly *b: Pointer to input polynomial that is subtracted from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + +MLKEM_NATIVE_INTERNAL_API 
+void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 3/2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constraint native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * Arguments: - 
uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients + * normalized in [0..4095].
+ * - const uint8_t *a: pointer to input byte array (of length MLKEM_POLYVECBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value.
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^n) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value.
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible.
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * Referring
to (eq 4.3), `OUT` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORDER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ */ +#endif /* defined(__BYTE_ORDER__) */ +#endif /* !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)).
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned conversion warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour.
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time.
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value.
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_ref/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. + */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/LICENSE similarity index 100% rename from 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/LICENSE rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/LICENSE diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. 
In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * These functions are meant to be trivial wrappers around the chosen native + * implementation. They are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. + */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions."
+#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is assumed to be in normal order. + * + * Arguments: - uint16_t *a: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial. 
+ * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial operand.
+ * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: De-serialization of a polynomial: + * converts a byte array into the + * polynomial it encodes.
+ * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Returns -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len).
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. + * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= 
MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. + * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2.
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-valued predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { unsigned k; (0 <= k && k < MLKEM_N) ==> ( + * (0 <= a->coeffs[k]) && (a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1.
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit a leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols.
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * than configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers.
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202__ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM__ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLKEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if
(err == 1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository.
+ * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. 
+ * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. + **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. 
*/ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + *key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * poly *pk: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input 
serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. 
+ * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block a time until we're done. + */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. */ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. 
*/ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. + */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. 
*/ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). 
+ **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + 
+#include +#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * an XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. 
+ * Described in Section 7.2 of FIPS203. + * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no secret information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * of MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define 
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < 
(1u << 11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); /* r is sampled via prf_eta2/poly_cbd_eta2, so check against eta2, not eta1 */ +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, 
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required.
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328).
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DU bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1].
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress_dv + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization.
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise non-negative and < UINT12_LIMIT. + * + * The result is coefficient-wise bound by 2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute().
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be subtracted + * from. + * - const poly *b: Pointer to second input polynomial + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing.
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + 
+MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constrain native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check.
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * 
Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ *              - const uint8_t *a: pointer to input byte array
+ *                                  (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU)
+ **************************************************/
+MLKEM_NATIVE_INTERNAL_API
+void polyvec_decompress_du(polyvec *r,
+                           const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU])
+__contract__(
+  requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU))
+  requires(memory_no_alias(r, sizeof(polyvec)))
+  assigns(object_whole(r))
+  ensures(forall(k0, 0, MLKEM_K,
+         array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))
+);
+
+#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes)
+/*************************************************
+ * Name:        polyvec_tobytes
+ *
+ * Description: Serialize vector of polynomials
+ *
+ * Arguments:   - uint8_t *r: pointer to output byte array
+ *                            (needs space for MLKEM_POLYVECBYTES)
+ *              - const polyvec *a: pointer to input vector of polynomials
+ *                  Each polynomial must have coefficients in [0,..,q-1].
+ **************************************************/
+MLKEM_NATIVE_INTERNAL_API
+void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a)
+__contract__(
+  requires(memory_no_alias(a, sizeof(polyvec)))
+  requires(memory_no_alias(r, MLKEM_POLYVECBYTES))
+  requires(forall(k0, 0, MLKEM_K,
+         array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))
+  assigns(object_whole(r))
+);
+
+#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes)
+/*************************************************
+ * Name:        polyvec_frombytes
+ *
+ * Description: De-serialize vector of polynomials;
+ *              inverse of polyvec_tobytes
+ *
+ * Arguments:   - polyvec *r: pointer to output vector of polynomials.
+ *                 Output will have coefficients normalized in [0..4095].
+ *              - const uint8_t *a: pointer to input byte array
+ *                 (of length MLKEM_POLYVECBYTES)
+ **************************************************/
+MLKEM_NATIVE_INTERNAL_API
+void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES])
+__contract__(
+  requires(memory_no_alias(r, sizeof(polyvec)))
+  requires(memory_no_alias(a, MLKEM_POLYVECBYTES))
+  assigns(object_whole(r))
+  ensures(forall(k0, 0, MLKEM_K,
+        array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))
+);
+
+#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt)
+/*************************************************
+ * Name:        polyvec_ntt
+ *
+ * Description: Apply forward NTT to all elements of a vector of polynomials.
+ *
+ *              The input is assumed to be in normal order and
+ *              coefficient-wise bound by MLKEM_Q in absolute value.
+ *
+ *              The output polynomial is in bitreversed order, and
+ *              coefficient-wise bound by NTT_BOUND in absolute value.
+ *
+ * Arguments:   - polyvec *r: pointer to in/output vector of polynomials
+ *
+ **************************************************/
+MLKEM_NATIVE_INTERNAL_API
+void polyvec_ntt(polyvec *r)
+__contract__(
+  requires(memory_no_alias(r, sizeof(polyvec)))
+  requires(forall(j, 0, MLKEM_K,
+  array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q)))
+  assigns(object_whole(r))
+  ensures(forall(j, 0, MLKEM_K,
+  array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND)))
+);
+
+#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont)
+/*************************************************
+ * Name:        polyvec_invntt_tomont
+ *
+ * Description: Apply inverse NTT to all elements of a vector of polynomials
+ *              and multiply by Montgomery factor 2^16
+ *
+ *              The input is assumed to be in bitreversed order, and can
+ *              have arbitrary coefficients in int16_t.
+ *
+ *              The output polynomial is in normal order, and
+ *              coefficient-wise bound by INVNTT_BOUND in absolute value.
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffers + * is guaranteed to have been consumed. 
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible. 
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffers + * is guaranteed to have been consumed. 
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * 
Referring to (eq 4.3), `IN` is assumed to contain `s || b`. */
+#define prf_eta(ETA, OUT, IN) \
+  shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1)
+#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN)
+#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN)
+#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \
+  shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \
+             IN2, IN3, MLKEM_SYMBYTES + 1)
+
+/* XOF function, FIPS-203 4.1 */
+#define xof_ctx shake128ctx
+#define xof_x4_ctx shake128x4ctx
+#define xof_absorb(CTX, IN, INBYTES) \
+  shake128_absorb_once((CTX), (IN), (INBYTES))
+#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \
+  shake128_squeezeblocks((BUF), (NBLOCKS), (CTX))
+#define xof_release(CTX) shake128_release((CTX))
+
+#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \
+  shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES))
+#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \
+  shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX))
+#define xof_x4_release(CTX) shake128x4_release((CTX))
+
+#define XOF_RATE SHAKE128_RATE
+
+#endif /* SYMMETRIC_H */
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/sys.h
new file mode 100644
index 0000000000..a5820fa195
--- /dev/null
+++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/sys.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2024 The mlkem-native project authors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+#ifndef MLKEM_NATIVE_SYS_H
+#define MLKEM_NATIVE_SYS_H
+
+/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by
+ * MSVC. */
+#if defined(__AARCH64EL__) || defined(_M_ARM64)
+#define SYS_AARCH64
+#endif
+
+/* Check if we're running on an AArch64 big endian system.
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORER__ */ +#endif /* !defined(__BYTE_ORER__) */ +#endif /* defined(SYS_LITTLE_ENDIAN) || defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)). 
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE int32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return (int32_t)ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour. 
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time. 
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value. 
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/README.md b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/README.md new file mode 100644 index 0000000000..2073425c3b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/README.md @@ -0,0 +1,4 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +This directory contains the native x86_64 arithmetic backend for ML-KEM provided by the official [AVX2 +implementation](https://github.com/pq-crystals/kyber/tree/main/avx2) of the Kyber team. diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/default.h new file mode 100644 index 0000000000..592e8996dc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/default.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME X86_64_DEFAULT + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. 
*/ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "x86_64/src/default_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/align.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/align.h new file mode 100644 index 0000000000..42a02fe57c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/align.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/align.h + */ + +#ifndef ALIGN_H +#define ALIGN_H + +#include +#include + +#define ALIGNED_UINT8(N) \ + union \ + { \ + uint8_t coeffs[N]; \ + __m256i vec[(N + 31) / 32]; \ + } + +#define ALIGNED_INT16(N) \ + union \ + { \ + int16_t coeffs[N]; \ + __m256i vec[(N + 15) / 16]; \ + } + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/arith_native_x86_64.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/arith_native_x86_64.h new file mode 100644 index 0000000000..ce13e7911f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/arith_native_x86_64.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_X86_64_NATIVE_H +#define MLKEM_X86_64_NATIVE_H + +#include "common.h" + +#include +#include +#include "polyvec.h" +#include "consts.h" + +#define REJ_UNIFORM_AVX_NBLOCKS 3 /* See MLKEM_GEN_MATRIX_NBLOCKS */ +#define REJ_UNIFORM_AVX_BUFLEN \ + (3 * 168) /* REJ_UNIFORM_AVX_NBLOCKS * SHAKE128_RATE */ + +#define rej_uniform_avx2 MLKEM_NAMESPACE(rej_uniform_avx2) +unsigned int rej_uniform_avx2(int16_t *r, const uint8_t *buf); + +#define rej_uniform_table MLKEM_NAMESPACE(rej_uniform_table) +extern const uint8_t rej_uniform_table[256][8]; + +#define ntt_avx2 MLKEM_NAMESPACE(ntt_avx2) +void ntt_avx2(__m256i *r, 
const __m256i *qdata); + +#define invntt_avx2 MLKEM_NAMESPACE(invntt_avx2) +void invntt_avx2(__m256i *r, const __m256i *qdata); + +#define nttpack_avx2 MLKEM_NAMESPACE(nttpack_avx2) +void nttpack_avx2(__m256i *r, const __m256i *qdata); + +#define nttunpack_avx2 MLKEM_NAMESPACE(nttunpack_avx2) +void nttunpack_avx2(__m256i *r, const __m256i *qdata); + +#define reduce_avx2 MLKEM_NAMESPACE(reduce_avx2) +void reduce_avx2(__m256i *r, const __m256i *qdata); + +#define basemul_avx2 MLKEM_NAMESPACE(basemul_avx2) +void basemul_avx2(__m256i *r, const __m256i *a, const __m256i *b, + const __m256i *qdata); + +#define polyvec_basemul_acc_montgomery_cached_avx2 \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_avx2) +void polyvec_basemul_acc_montgomery_cached_avx2( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); + +#define ntttobytes_avx2 MLKEM_NAMESPACE(ntttobytes_avx2) +void ntttobytes_avx2(uint8_t *r, const __m256i *a, const __m256i *qdata); + +#define nttfrombytes_avx2 MLKEM_NAMESPACE(nttfrombytes_avx2) +void nttfrombytes_avx2(__m256i *r, const uint8_t *a, const __m256i *qdata); + +#define tomont_avx2 MLKEM_NAMESPACE(tomont_avx2) +void tomont_avx2(__m256i *r, const __m256i *qdata); + +#endif /* MLKEM_X86_64_NATIVE_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/basemul.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.S similarity index 61% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/basemul.S rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.S index 36990639b2..b97840e702 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/basemul.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.S @@ -1,12 +1,25 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// 
https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + #include "consts.h" +/* Polynomials to be multiplied are denoted a+bX (rsi arg) and c+dX (rdx arg) */ .macro schoolbook off -vmovdqa _16XQINV*2(%rcx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQINV*2(%rcx),%ymm0 vmovdqa (64*\off+ 0)*2(%rsi),%ymm1 # a0 vmovdqa (64*\off+16)*2(%rsi),%ymm2 # b0 vmovdqa (64*\off+32)*2(%rsi),%ymm3 # a1 vmovdqa (64*\off+48)*2(%rsi),%ymm4 # b1 +/* Prepare Montgomery twists */ vpmullw %ymm0,%ymm1,%ymm9 # a0.lo vpmullw %ymm0,%ymm2,%ymm10 # b0.lo vpmullw %ymm0,%ymm3,%ymm11 # a1.lo @@ -15,6 +28,7 @@ vpmullw %ymm0,%ymm4,%ymm12 # b1.lo vmovdqa (64*\off+ 0)*2(%rdx),%ymm5 # c0 vmovdqa (64*\off+16)*2(%rdx),%ymm6 # d0 +/* Compute high-parts of monomials in (a0+b0*X)*(c0+d0*X) */ vpmulhw %ymm5,%ymm1,%ymm13 # a0c0.hi vpmulhw %ymm6,%ymm1,%ymm1 # a0d0.hi vpmulhw %ymm5,%ymm2,%ymm14 # b0c0.hi @@ -23,6 +37,8 @@ vpmulhw %ymm6,%ymm2,%ymm2 # b0d0.hi vmovdqa (64*\off+32)*2(%rdx),%ymm7 # c1 vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1 +/* Compute high-parts of monomials in (a1+b1*X)*(c1+d1*X) */ +/* Don't yet accumulate nor reduce X^2 */ vpmulhw %ymm7,%ymm3,%ymm15 # a1c1.hi vpmulhw %ymm8,%ymm3,%ymm3 # a1d1.hi vpmulhw %ymm7,%ymm4,%ymm0 # b1c1.hi @@ -30,17 +46,22 @@ vpmulhw %ymm8,%ymm4,%ymm4 # b1d1.hi vmovdqa %ymm13,(%rsp) +/* Compute low-parts of monomials in (a0+b0*X)*(c0+d0*X), */ +/* using Montgomery twists calculated before */ vpmullw %ymm5,%ymm9,%ymm13 # a0c0.lo vpmullw %ymm6,%ymm9,%ymm9 # a0d0.lo vpmullw %ymm5,%ymm10,%ymm5 # b0c0.lo vpmullw %ymm6,%ymm10,%ymm10 # b0d0.lo +/* Compute low-parts of monomials in (a1+b1*X)*(c1+d1*X), */ +/* using Montgomery twists calculated before */ vpmullw %ymm7,%ymm11,%ymm6 # a1c1.lo vpmullw %ymm8,%ymm11,%ymm11 # a1d1.lo vpmullw %ymm7,%ymm12,%ymm7 # b1c1.lo vpmullw %ymm8,%ymm12,%ymm12 # b1d1.lo -vmovdqa _16XQ*2(%rcx),%ymm8 +/* Compute 2nd high multiplication in Montgomery multiplication */ +vmovdqa 
AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rcx),%ymm8 vpmulhw %ymm8,%ymm13,%ymm13 vpmulhw %ymm8,%ymm9,%ymm9 vpmulhw %ymm8,%ymm5,%ymm5 @@ -50,6 +71,7 @@ vpmulhw %ymm8,%ymm11,%ymm11 vpmulhw %ymm8,%ymm7,%ymm7 vpmulhw %ymm8,%ymm12,%ymm12 +/* Finish Montgomery multiplications */ vpsubw (%rsp),%ymm13,%ymm13 # -a0c0 vpsubw %ymm9,%ymm1,%ymm9 # a0d0 vpsubw %ymm5,%ymm14,%ymm5 # b0c0 @@ -60,6 +82,10 @@ vpsubw %ymm11,%ymm3,%ymm11 # a1d1 vpsubw %ymm7,%ymm0,%ymm7 # b1c1 vpsubw %ymm12,%ymm4,%ymm12 # b1d1 +/* b0*d0 and b1*d1 need twisting by a twiddle, accounting + * for X^2=zeta in F_q[X]/(X^2-zeta). + * + * TODO: This could be precomputed in the mulcache */ vmovdqa (%r9),%ymm0 vmovdqa 32(%r9),%ymm1 vpmullw %ymm0,%ymm10,%ymm2 @@ -76,6 +102,9 @@ vpaddw %ymm7,%ymm11,%ymm11 vpsubw %ymm13,%ymm10,%ymm13 vpsubw %ymm12,%ymm6,%ymm6 +/* Bounds: Since we are multiplying with signed canonical twiddles, + * each Montgomery multiplication has absolute value < q, + * and hence the coefficients of the output have absolute value < 2q. 
*/ vmovdqa %ymm13,(64*\off+ 0)*2(%rdi) vmovdqa %ymm9,(64*\off+16)*2(%rdi) vmovdqa %ymm6,(64*\off+32)*2(%rdi) @@ -83,13 +112,13 @@ vmovdqa %ymm11,(64*\off+48)*2(%rdi) .endm .text -.global cdecl(basemul_avx) -cdecl(basemul_avx): +.global MLKEM_ASM_NAMESPACE(basemul_avx2) +MLKEM_ASM_NAMESPACE(basemul_avx2): mov %rsp,%r8 and $-32,%rsp sub $32,%rsp -lea (_ZETAS_EXP+176)*2(%rcx),%r9 +lea (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+176)*2(%rcx),%r9 schoolbook 0 add $32*2,%r9 @@ -103,3 +132,5 @@ schoolbook 3 mov %r8,%rsp ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.c new file mode 100644 index 0000000000..5f9ae99c80 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/basemul.c @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "poly.h" +#include "polyvec.h" + +#include "arith_native_x86_64.h" +#include "consts.h" + +static void poly_basemul_montgomery_avx2(poly *r, const poly *a, const poly *b) +{ + basemul_avx2((__m256i *)r->coeffs, (const __m256i *)a->coeffs, + (const __m256i *)b->coeffs, qdata.vec); +} + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ +static void poly_add_avx2(poly *r, const poly *a, const poly *b) +{ + unsigned i; + __m256i f0, f1; + + for (i = 0; i < MLKEM_N; i += 16) + { + f0 = _mm256_load_si256((const __m256i *)&a->coeffs[i]); + f1 = _mm256_load_si256((const __m256i *)&b->coeffs[i]); + f0 = _mm256_add_epi16(f0, f1); + _mm256_store_si256((__m256i *)&r->coeffs[i], f0); + } +} + +void polyvec_basemul_acc_montgomery_cached_avx2(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + /* 
TODO: Use mulcache for AVX2. So far, it is unused. */ + ((void)b_cache); + + /* Coefficient-wise bound of each basemul is 2q. + * Since we are accumulating at most 4 times, the + * overall bound is 8q < INT16_MAX. */ + poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_avx2(&t, &a->vec[i], &b->vec[i]); + poly_add_avx2(r, r, &t); + } +} + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy constant to keep compiler happy despite empty CU */ + +#define empty_cu_avx2_basemul MLKEM_NAMESPACE(empty_cu_avx2_basemul) +int empty_cu_avx2_basemul; + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.c new file mode 100644 index 0000000000..86a0835efd --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.c @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/consts.c + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "align.h" +#include "consts.h" + +#define Q MLKEM_Q +#define MONT -1044 /* 2^16 mod q */ +#define QINV -3327 /* q^-1 mod 2^16 */ +#define V 20159 /* floor(2^26/q + 0.5) */ +#define FHI 1441 /* mont^2/128 */ +#define FLO -10079 /* qinv*FHI */ +#define MONTSQHI 1353 /* mont^2 */ +#define MONTSQLO 20553 /* qinv*MONTSQHI */ +#define MASK 4095 +#define SHIFT 32 + +const qdata_t qdata = {{ +#define AVX2_BACKEND_DATA_OFFSET_16XQ 0 + Q, Q, Q, Q, Q, Q, + Q, Q, Q, Q, Q, Q, + Q, Q, Q, Q, + +#define AVX2_BACKEND_DATA_OFFSET_16XQINV 16 + QINV, QINV, QINV, QINV, QINV, QINV, + QINV, QINV, QINV, QINV, QINV, QINV, + QINV, QINV, QINV, QINV, + +#define AVX2_BACKEND_DATA_OFFSET_16XV 32 + V, V, V, 
V, V, V, + V, V, V, V, V, V, + V, V, V, V, + +#define AVX2_BACKEND_DATA_OFFSET_16XFLO 48 + FLO, FLO, FLO, FLO, FLO, FLO, + FLO, FLO, FLO, FLO, FLO, FLO, + FLO, FLO, FLO, FLO, + +#define AVX2_BACKEND_DATA_OFFSET_16XFHI 64 + FHI, FHI, FHI, FHI, FHI, FHI, + FHI, FHI, FHI, FHI, FHI, FHI, + FHI, FHI, FHI, FHI, + +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO 80 + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI 96 + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + +#define AVX2_BACKEND_DATA_OFFSET_16XMASK 112 + MASK, MASK, MASK, MASK, MASK, MASK, + MASK, MASK, MASK, MASK, MASK, MASK, + MASK, MASK, MASK, MASK, + +#define AVX2_BACKEND_DATA_OFFSET_REVIDXB 128 + 3854, 3340, 2826, 2312, 1798, 1284, + 770, 256, 3854, 3340, 2826, 2312, + 1798, 1284, 770, 256, + +#define AVX2_BACKEND_DATA_OFFSET_REVIDXD 144 + 7, 0, 6, 0, 5, 0, + 4, 0, 3, 0, 2, 0, + 1, 0, 0, 0, + +#define AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP 160 +#include "x86_64_zetas.i" + +#define AVX2_BACKEND_DATA_OFFSET_16XSHIFT 624 + SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, + SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, + SHIFT, SHIFT, SHIFT, SHIFT}}; + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_consts MLKEM_NAMESPACE(empty_cu_consts) +int empty_cu_consts; +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.h new file mode 100644 index 0000000000..00c415952e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/consts.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 The mlkem-native 
project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/consts.h + */ + +#ifndef CONSTS_H +#define CONSTS_H + +#include "common.h" + +#define AVX2_BACKEND_DATA_OFFSET_16XQ 0 +#define AVX2_BACKEND_DATA_OFFSET_16XQINV 16 +#define AVX2_BACKEND_DATA_OFFSET_16XV 32 +#define AVX2_BACKEND_DATA_OFFSET_16XFLO 48 +#define AVX2_BACKEND_DATA_OFFSET_16XFHI 64 +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO 80 +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI 96 +#define AVX2_BACKEND_DATA_OFFSET_16XMASK 112 +#define AVX2_BACKEND_DATA_OFFSET_REVIDXB 128 +#define AVX2_BACKEND_DATA_OFFSET_REVIDXD 144 +#define AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP 160 +#define AVX2_BACKEND_DATA_OFFSET_16XSHIFT 624 + +/* The C ABI on MacOS exports all symbols with a leading + * underscore. This means that any symbols we refer to from + * C files (functions) can't be found, and all symbols we + * refer to from ASM also can't be found. + * + * This define helps us get around this + */ + +#ifndef __ASSEMBLER__ +#include "align.h" +typedef ALIGNED_INT16(640) qdata_t; +#define qdata MLKEM_NAMESPACE(qdata) +extern const qdata_t qdata; +#endif + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/default_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/default_impl.h new file mode 100644 index 0000000000..66de8c85f3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/default_impl.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include + +#include "poly.h" +#include "polyvec.h" +#include "arith_native_x86_64.h" + +#define MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + +#define MLKEM_USE_NATIVE_REJ_UNIFORM +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_POLY_FROMBYTES + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +#define NTT_BOUND_NATIVE (8 * MLKEM_Q) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +{ + nttunpack_avx2((__m256i *)(data->coeffs), qdata.vec); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + /* AVX2 implementation assumes specific buffer lengths */ + if (len != MLKEM_N || buflen != REJ_UNIFORM_AVX_BUFLEN) + { + return -1; + } + + return (int)rej_uniform_avx2(r, buf); +} + +static INLINE void ntt_native(poly *data) +{ + ntt_avx2((__m256i *)data, qdata.vec); +} + +static INLINE void intt_native(poly *data) +{ + invntt_avx2((__m256i *)data, qdata.vec); +} + +static INLINE void poly_reduce_native(poly *data) +{ + reduce_avx2((__m256i *)data->coeffs, qdata.vec); +} + +static INLINE void poly_tomont_native(poly *data) +{ + tomont_avx2((__m256i *)data->coeffs, qdata.vec); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + /* AVX2 backend does not use mulcache */ + ((void)y); + ((void)x); +} + +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_avx2(r, a, b, b_cache); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + ntttobytes_avx2(r, (const __m256i *)a->coeffs, qdata.vec); 
+} + +static INLINE void poly_frombytes_native(poly *r, + const uint8_t a[MLKEM_POLYBYTES]) +{ + nttfrombytes_avx2((__m256i *)r->coeffs, a, qdata.vec); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.S similarity index 50% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.S rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.S index 3bb1ebd3d8..134bd4f710 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.S @@ -1,8 +1,25 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation based on Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +// Changes: +// - Add call to csub in reduce128_avx to produce outputs +// in [0,1,...,q-1] rather than [0,1,...,q], matching the +// semantics of poly_reduce(). 
+ +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) #include "consts.h" -.include "fq.inc" + +#include "fq.inc" .text -reduce128_avx: +reduce128_avx2: #load vmovdqa (%rdi),%ymm2 vmovdqa 32(%rdi),%ymm3 @@ -22,6 +39,15 @@ red16 7 red16 8 red16 9 +csubq 2 +csubq 3 +csubq 4 +csubq 5 +csubq 6 +csubq 7 +csubq 8 +csubq 9 + #store vmovdqa %ymm2,(%rdi) vmovdqa %ymm3,32(%rdi) @@ -34,17 +60,18 @@ vmovdqa %ymm9,224(%rdi) ret -.global cdecl(reduce_avx) -cdecl(reduce_avx): +.global MLKEM_ASM_NAMESPACE(reduce_avx2) +MLKEM_ASM_NAMESPACE(reduce_avx2): #consts -vmovdqa _16XQ*2(%rsi),%ymm0 -vmovdqa _16XV*2(%rsi),%ymm1 -call reduce128_avx +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XV*2(%rsi),%ymm1 +call reduce128_avx2 add $256,%rdi -call reduce128_avx +call reduce128_avx2 ret -tomont128_avx: + +tomont128_avx2: #load vmovdqa (%rdi),%ymm3 vmovdqa 32(%rdi),%ymm4 @@ -76,13 +103,15 @@ vmovdqa %ymm10,224(%rdi) ret -.global cdecl(tomont_avx) -cdecl(tomont_avx): +.global MLKEM_ASM_NAMESPACE(tomont_avx2) +MLKEM_ASM_NAMESPACE(tomont_avx2): #consts -vmovdqa _16XQ*2(%rsi),%ymm0 -vmovdqa _16XMONTSQLO*2(%rsi),%ymm1 -vmovdqa _16XMONTSQHI*2(%rsi),%ymm2 -call tomont128_avx +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO*2(%rsi),%ymm1 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI*2(%rsi),%ymm2 +call tomont128_avx2 add $256,%rdi -call tomont128_avx +call tomont128_avx2 ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/fq.inc b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.inc similarity index 67% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/fq.inc rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.inc index 4b7afc3118..76ec7a3b9e 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/fq.inc +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/fq.inc @@ -1,3 +1,13 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ + .macro red16 r,rs=0,x=12 vpmulhw %ymm1,%ymm\r,%ymm\x .if \rs @@ -22,6 +32,8 @@ vpand %ymm0,%ymm\x,%ymm\x vpaddw %ymm\x,%ymm\r,%ymm\r .endm +/* Montgomery multiplication between b and ah, + * with Montgomery twist of ah in al. */ .macro fqmulprecomp al,ah,b,x=12 vpmullw %ymm\al,%ymm\b,%ymm\x vpmulhw %ymm\ah,%ymm\b,%ymm\b diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/intt.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/intt.S new file mode 100644 index 0000000000..6b1d78ef26 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/intt.S @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* Implementation based on Kyber repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + * + * Changes to placement of modular reductions have + * been made to simplify reasoning of non-overflow */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "consts.h" +#include "shuffle.inc" +#include "fq.inc" + +/* Compute four GS butterflies between rh{0,1,2,3} and rl{0,1,2,3}. 
+ * Butterflies 0,1 use root zh0 and twisted root zl0, and butterflies + * 2,3 use root zh1 and twisted root zl1 + * Results are again in rl{0-3} and rh{0-3} */ +.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3 +vpsubw %ymm\rl0,%ymm\rh0,%ymm12 /* ymm12 = rh0 - rl0 */ +vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 /* rl0 = rh0 + rl0 */ +vpsubw %ymm\rl1,%ymm\rh1,%ymm13 /* ymm13 = rh1 - rl1 */ + +vpmullw %ymm\zl0,%ymm12,%ymm\rh0 /* rh0 = (rh0 - rl0) * root0_twisted */ +vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 /* rl1 = rh1 + rl1 */ +vpsubw %ymm\rl2,%ymm\rh2,%ymm14 /* ymm14 = rh2 - rl2 */ + +vpmullw %ymm\zl0,%ymm13,%ymm\rh1 /* rh1 = (rh1 - rl1) * root0_twisted */ +vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 /* rl2 = rh2 + rl2 */ +vpsubw %ymm\rl3,%ymm\rh3,%ymm15 /* ymm15 = rh3 - rl3 */ + +vpmullw %ymm\zl1,%ymm14,%ymm\rh2 /* rh2 = (rh2 - rl2) * root1_twisted */ +vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 /* rl3 = rh3 + rl3 */ +vpmullw %ymm\zl1,%ymm15,%ymm\rh3 /* rh3 = (rh3 - rl3) * root1_twisted */ + +vpmulhw %ymm\zh0,%ymm12,%ymm12 /* ymm12 = (rh0 - rl0) * root0 */ +vpmulhw %ymm\zh0,%ymm13,%ymm13 /* ymm13 = (rh1 - rl1) * root0 */ + +vpmulhw %ymm\zh1,%ymm14,%ymm14 /* ymm14 = (rh2 - rl2) * root1 */ +vpmulhw %ymm\zh1,%ymm15,%ymm15 /* ymm15 = (rh3 - rl3) * root1 */ + +vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 /* rh0 = Q * [(rh0 - rl0) * root0_twisted] */ +vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 /* rh1 = Q * [(rh1 - rl1) * root0_twisted] */ +vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 /* rh2 = Q * [(rh2 - rl2) * root1_twisted] */ +vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 /* rh3 = Q * [(rh3 - rl3) * root1_twisted] */ + +vpsubw %ymm\rh0,%ymm12,%ymm\rh0 /* rh0 = montmul(rh0-rl0, root0) */ +vpsubw %ymm\rh1,%ymm13,%ymm\rh1 /* rh1 = montmul(rh1-rl1, root0) */ +vpsubw %ymm\rh2,%ymm14,%ymm\rh2 /* rh2 = montmul(rh2-rl2, root1) */ +vpsubw %ymm\rh3,%ymm15,%ymm\rh3 /* rh3 = montmul(rh3-rl3, root1) */ +.endm + +.macro intt_levels0t5 off +/* level 0 */ +/* no bounds assumptions */ +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XFLO*2(%rsi),%ymm2 
+vmovdqa AVX2_BACKEND_DATA_OFFSET_16XFHI*2(%rsi),%ymm3 + +vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 + +fqmulprecomp 2,3,4 +fqmulprecomp 2,3,6 +fqmulprecomp 2,3,5 +fqmulprecomp 2,3,7 + +vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 +vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 +vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 +vmovdqa (128*\off+112)*2(%rdi),%ymm11 + +fqmulprecomp 2,3,8 +fqmulprecomp 2,3,10 +fqmulprecomp 2,3,9 +fqmulprecomp 2,3,11 + +/* bounds: coefficients < q */ + +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+192)*2(%rsi),%ymm3 +vmovdqa AVX2_BACKEND_DATA_OFFSET_REVIDXB*2(%rsi),%ymm12 +vpshufb %ymm12,%ymm15,%ymm15 +vpshufb %ymm12,%ymm1,%ymm1 +vpshufb %ymm12,%ymm2,%ymm2 +vpshufb %ymm12,%ymm3,%ymm3 + +butterfly 4,5,8,9,6,7,10,11,15,1,2,3 + +/* Montgomery multiplication with a signed canonical twiddle + * always has absolute value < q. This is used henceforth to + * normalize the absolute bounds on the second half inputs + * to the current butterfly + * + * 4,5,8,9 abs bound < 2q; 6,7,10,11 abs bound < q */ + +/* level 1 */ +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3 +vmovdqa AVX2_BACKEND_DATA_OFFSET_REVIDXB*2(%rsi),%ymm1 +vpshufb %ymm1,%ymm2,%ymm2 +vpshufb %ymm1,%ymm3,%ymm3 + +butterfly 4,5,6,7,8,9,10,11,2,2,3,3 + +/* For 8,9,10,11, it is sufficient to use the bound INT16_MAX. 
*/ +red16 7 +/* global abs bound < 4q */ + +vmovdqa %ymm7,(128*\off+ 0)*2(%rdi) +vmovdqa %ymm9,(128*\off+ 16)*2(%rdi) +vmovdqa %ymm6,(128*\off+ 32)*2(%rdi) +vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) +vmovdqa %ymm10,(128*\off+ 64)*2(%rdi) +vmovdqa %ymm4,(128*\off+ 80)*2(%rdi) +vmovdqa %ymm5,(128*\off+ 96)*2(%rdi) +vmovdqa %ymm11,(128*\off+112)*2(%rdi) +.endm + +.macro intt_level6 off +/* level 6 */ +vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (64*\off+128)*2(%rdi),%ymm8 +vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (64*\off+144)*2(%rdi),%ymm9 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+0)*2(%rsi),%ymm2 + +vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (64*\off+160)*2(%rdi),%ymm10 +vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 +vmovdqa (64*\off+176)*2(%rdi),%ymm11 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+4)*2(%rsi),%ymm3 + +butterfly 4,5,6,7,8,9,10,11 +/* global abs bound < 8q */ + +/* REF-CHANGE: The official AVX2 implementation has a `red16 4` for `off=0`. + * We don't need this because of the earlier red16 which ensures an 8q bound */ + +vmovdqa %ymm4,(64*\off+ 0)*2(%rdi) +vmovdqa %ymm5,(64*\off+ 16)*2(%rdi) +vmovdqa %ymm6,(64*\off+ 32)*2(%rdi) +vmovdqa %ymm7,(64*\off+ 48)*2(%rdi) +vmovdqa %ymm8,(64*\off+128)*2(%rdi) +vmovdqa %ymm9,(64*\off+144)*2(%rdi) +vmovdqa %ymm10,(64*\off+160)*2(%rdi) +vmovdqa %ymm11,(64*\off+176)*2(%rdi) +.endm + +.text +.global MLKEM_ASM_NAMESPACE(invntt_avx2) +MLKEM_ASM_NAMESPACE(invntt_avx2): +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 + +intt_levels0t5 0 +intt_levels0t5 1 + +intt_level6 0 +intt_level6 1 +ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/ntt.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/ntt.S new file mode 100644 index 0000000000..e8bf7894b4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/ntt.S @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "consts.h" +#include "shuffle.inc" + +/* Compute steps 1,2 / 3 of Montgomery multiplication */ +.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2 +vpmullw %ymm\zl0,%ymm\rh0,%ymm12 +vpmullw %ymm\zl0,%ymm\rh1,%ymm13 + +vpmullw %ymm\zl1,%ymm\rh2,%ymm14 +vpmullw %ymm\zl1,%ymm\rh3,%ymm15 + +vpmulhw %ymm\zh0,%ymm\rh0,%ymm\rh0 +vpmulhw %ymm\zh0,%ymm\rh1,%ymm\rh1 + +vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2 +vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3 +.endm + +/* Compute step 3 / 3 of Montgomery multiplication */ +/* Multiply-high is signed; outputs are bound by 2^15 * q in abs value */ +.macro reduce +vpmulhw %ymm0,%ymm12,%ymm12 +vpmulhw %ymm0,%ymm13,%ymm13 + +vpmulhw %ymm0,%ymm14,%ymm14 +vpmulhw %ymm0,%ymm15,%ymm15 +.endm + +/* Finish Montgomery multiplication and compute add/sub steps in NTT butterfly + * + * At this point, the two high-products of 4 ongoing Montgomery multiplications + * are in %ymm{12,13,14,15} and %ymm{rh{0,1,2,3}}, respectively. + * The NTT coefficients that the results of the Montgomery multiplications should + * be add/sub-ed with, are in %ymm{rl{0,1,2,3}}. + * + * What's interesting, here, is that rather than completing the Montgomery + * multiplications by computing `%ymm{12+i} + %ymm{rh{i}}`, and then add/sub'ing + * the result into %ymm{rl{0,1,2,3}}, we add/sub both `%ymm{12+i}` and + * %ymm{rh{i}} to %ymm{rl{0,1,2,3}}, and then add the results. + * + * Functionally, though, this is still a signed Montgomery multiplication + * followed by an add/sub. + * + * Since the result of the Montgomery multiplication is bounded + * by q in absolute value, the coefficients overall grow by not + * more than q in absolute value per layer. 
*/ +.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3 +vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln /* rln = rl0 + rh0 */ +vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 /* rh0 = rl0 - rh0 */ +vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 /* rl0 = rl1 + rh1 */ +vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 /* rh1 = rl1 - rh1 */ +vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 /* rl1 = rl2 + rh2 */ +vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 /* rh2 = rl2 - rh2 */ +vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 /* rl2 = rl3 + rh3 */ +vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 /* rh3 = rl3 - rh3 */ + +vpsubw %ymm12,%ymm\rln,%ymm\rln /* rln = rh0 + rl0 - ymm12 = rl0 + (rh0 - ymm12) */ +vpaddw %ymm12,%ymm\rh0,%ymm\rh0 /* rh0 = rl0 - rh0 + ymm12 = rl0 - (rh0 - ymm12) */ +vpsubw %ymm13,%ymm\rl0,%ymm\rl0 /* rl0 = rl1 + rh1 - ymm13 = rl1 + (rh1 - ymm13) */ +vpaddw %ymm13,%ymm\rh1,%ymm\rh1 /* rh1 = rl1 - rh1 + ymm13 = rl1 - (rh1 - ymm13) */ +vpsubw %ymm14,%ymm\rl1,%ymm\rl1 /* rl1 = rh2 + rl2 - ymm14 = rl2 + (rh2 - ymm14) */ +vpaddw %ymm14,%ymm\rh2,%ymm\rh2 /* rh2 = rl2 - rh2 + ymm14 = rl2 - (rh2 - ymm14) */ +vpsubw %ymm15,%ymm\rl2,%ymm\rl2 /* rl2 = rh3 + rl3 - ymm15 = rl3 + (rh3 - ymm15) */ +vpaddw %ymm15,%ymm\rh3,%ymm\rh3 /* rh3 = rl3 - rh3 + ymm15 = rl3 - (rh3 - ymm15) */ +.endm + +.macro level0 off +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+0)*2(%rsi),%ymm15 +vmovdqa (64*\off+128)*2(%rdi),%ymm8 +vmovdqa (64*\off+144)*2(%rdi),%ymm9 +vmovdqa (64*\off+160)*2(%rdi),%ymm10 +vmovdqa (64*\off+176)*2(%rdi),%ymm11 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+4)*2(%rsi),%ymm2 + +mul 8,9,10,11 + +vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 + +reduce +update 3,4,5,6,7,8,9,10,11 + +vmovdqa %ymm3,(64*\off+ 0)*2(%rdi) +vmovdqa %ymm4,(64*\off+ 16)*2(%rdi) +vmovdqa %ymm5,(64*\off+ 32)*2(%rdi) +vmovdqa %ymm6,(64*\off+ 48)*2(%rdi) +vmovdqa %ymm8,(64*\off+128)*2(%rdi) +vmovdqa %ymm9,(64*\off+144)*2(%rdi) +vmovdqa %ymm10,(64*\off+160)*2(%rdi) +vmovdqa 
%ymm11,(64*\off+176)*2(%rdi) +.endm + +.macro levels1t6 off +/* level 1 */ +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+16)*2(%rsi),%ymm15 +vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 +vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 +vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 +vmovdqa (128*\off+112)*2(%rdi),%ymm11 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+32)*2(%rsi),%ymm2 + +mul 8,9,10,11 + +vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 + +reduce +update 3,4,5,6,7,8,9,10,11 + +/* level 2 */ +shuffle8 5,10,7,10 +shuffle8 6,11,5,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+48)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+64)*2(%rsi),%ymm2 + +mul 7,10,5,11 + +shuffle8 3,8,6,8 +shuffle8 4,9,3,9 + +reduce +update 4,6,8,3,9,7,10,5,11 + +/* level 3 */ +shuffle4 8,5,9,5 +shuffle4 3,11,8,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+80)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+96)*2(%rsi),%ymm2 + +mul 9,5,8,11 + +shuffle4 4,7,3,7 +shuffle4 6,10,4,10 + +reduce +update 6,3,7,4,10,9,5,8,11 + +/* level 4 */ +shuffle2 7,8,10,8 +shuffle2 4,11,7,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+112)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+128)*2(%rsi),%ymm2 + +mul 10,8,7,11 + +shuffle2 6,9,4,9 +shuffle2 3,5,6,5 + +reduce +update 3,4,9,6,5,10,8,7,11 + +/* level 5 */ +shuffle1 9,7,5,7 +shuffle1 6,11,9,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+144)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+160)*2(%rsi),%ymm2 + +mul 5,7,9,11 + +shuffle1 3,10,6,10 +shuffle1 4,8,3,8 + +reduce +update 4,6,10,3,8,5,7,9,11 + +/* level 6 */ +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+176)*2(%rsi),%ymm14 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+208)*2(%rsi),%ymm15 +vmovdqa 
(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+192)*2(%rsi),%ymm8 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+224)*2(%rsi),%ymm2 + +mul 10,3,9,11,14,15,8,2 + +reduce +update 8,4,6,5,7,10,3,9,11 + +vmovdqa %ymm8,(128*\off+ 0)*2(%rdi) +vmovdqa %ymm4,(128*\off+ 16)*2(%rdi) +vmovdqa %ymm10,(128*\off+ 32)*2(%rdi) +vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) +vmovdqa %ymm6,(128*\off+ 64)*2(%rdi) +vmovdqa %ymm5,(128*\off+ 80)*2(%rdi) +vmovdqa %ymm9,(128*\off+ 96)*2(%rdi) +vmovdqa %ymm11,(128*\off+112)*2(%rdi) +.endm + +.text +.global MLKEM_ASM_NAMESPACE(ntt_avx2) +MLKEM_ASM_NAMESPACE(ntt_avx2): +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 + +level0 0 +level0 1 + +levels1t6 0 +levels1t6 1 + +ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_avx2.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_avx2.c new file mode 100644 index 0000000000..54037a0df9 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_avx2.c @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include +#include +#include +#include "arith_native_x86_64.h" +#include "consts.h" + +unsigned int rej_uniform_avx2(int16_t *RESTRICT r, const uint8_t *buf) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + uint32_t good; + const __m256i bound = + _mm256_load_si256(&qdata.vec[AVX2_BACKEND_DATA_OFFSET_16XQ / 16]); + const __m256i ones = _mm256_set1_epi8(1); + const __m256i mask = _mm256_set1_epi16(0xFFF); + const __m256i idx8 = + _mm256_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4, + 11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0); + __m256i f0, f1, g0, g1, g2, 
g3; + __m128i f, t, pilo, pihi; + + ctr = pos = 0; + while (ctr <= MLKEM_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 48) + { + f0 = _mm256_loadu_si256((__m256i *)&buf[pos]); + /* Don't load from offset 24, as this would over-read the buffer */ + f1 = _mm256_loadu_si256((__m256i *)&buf[pos + 16]); + f0 = _mm256_permute4x64_epi64(f0, 0x94 /* 0b10010100 ~= (2,1,1,0) */); + f1 = _mm256_permute4x64_epi64(f1, 0xe9 /* 0b11101001 ~= (3,2,2,1) */); + f0 = _mm256_shuffle_epi8(f0, idx8); + f1 = _mm256_shuffle_epi8(f1, idx8); + g0 = _mm256_srli_epi16(f0, 4); + g1 = _mm256_srli_epi16(f1, 4); + f0 = _mm256_blend_epi16(f0, g0, 0xAA); + f1 = _mm256_blend_epi16(f1, g1, 0xAA); + f0 = _mm256_and_si256(f0, mask); + f1 = _mm256_and_si256(f1, mask); + pos += 48; + + g0 = _mm256_cmpgt_epi16(bound, f0); + g1 = _mm256_cmpgt_epi16(bound, f1); + + g0 = _mm256_packs_epi16(g0, g1); + good = _mm256_movemask_epi8(g0); + + g0 = _mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 0) & 0xFF])); + g1 = _mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 8) & 0xFF])); + g0 = _mm256_inserti128_si256( + g0, _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 16) & 0xFF]), + 1); + g1 = _mm256_inserti128_si256( + g1, _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 24) & 0xFF]), + 1); + + g2 = _mm256_add_epi8(g0, ones); + g3 = _mm256_add_epi8(g1, ones); + g0 = _mm256_unpacklo_epi8(g0, g2); + g1 = _mm256_unpacklo_epi8(g1, g3); + + f0 = _mm256_shuffle_epi8(f0, g0); + f1 = _mm256_shuffle_epi8(f1, g1); + + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f0)); + ctr += _mm_popcnt_u32((good >> 0) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f0, 1)); + ctr += _mm_popcnt_u32((good >> 16) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f1)); + ctr += _mm_popcnt_u32((good >> 8) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f1, 1)); + ctr += _mm_popcnt_u32((good 
>> 24) & 0xFF); + } + + while (ctr <= MLKEM_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 24) + { + f = _mm_loadu_si128((__m128i *)&buf[pos]); + f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8)); + t = _mm_srli_epi16(f, 4); + f = _mm_blend_epi16(f, t, 0xAA); + f = _mm_and_si128(f, _mm256_castsi256_si128(mask)); + pos += 12; + + t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f); + good = _mm_movemask_epi8(t); + + good = _pext_u32(good, 0x5555); + pilo = _mm_loadl_epi64((__m128i *)&rej_uniform_table[good]); + + pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones)); + pilo = _mm_unpacklo_epi8(pilo, pihi); + f = _mm_shuffle_epi8(f, pilo); + _mm_storeu_si128((__m128i *)&r[ctr], f); + ctr += _mm_popcnt_u32(good); + } + + while (ctr < MLKEM_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)); + pos += 3; + + if (val0 < MLKEM_Q) + r[ctr++] = val0; + if (val1 < MLKEM_Q && ctr < MLKEM_N) + r[ctr++] = val1; + } + + return ctr; +} + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_rej_uniform_avx2 MLKEM_NAMESPACE(empty_cu_rej_uniform_avx2) +int empty_cu_rej_uniform_avx2; +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_table.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_table.c new file mode 100644 index 0000000000..9bbc47146f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/rej_uniform_table.c @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. 
+ */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include +#include "arith_native_x86_64.h" + +/* + * Lookup table used by rejection sampling of the public matrix. + * See autogen for details. + */ +ALIGN const uint8_t rej_uniform_table[256][8] = { + {-1, -1, -1, -1, -1, -1, -1, -1}, {0, -1, -1, -1, -1, -1, -1, -1}, + {2, -1, -1, -1, -1, -1, -1, -1}, {0, 2, -1, -1, -1, -1, -1, -1}, + {4, -1, -1, -1, -1, -1, -1, -1}, {0, 4, -1, -1, -1, -1, -1, -1}, + {2, 4, -1, -1, -1, -1, -1, -1}, {0, 2, 4, -1, -1, -1, -1, -1}, + {6, -1, -1, -1, -1, -1, -1, -1}, {0, 6, -1, -1, -1, -1, -1, -1}, + {2, 6, -1, -1, -1, -1, -1, -1}, {0, 2, 6, -1, -1, -1, -1, -1}, + {4, 6, -1, -1, -1, -1, -1, -1}, {0, 4, 6, -1, -1, -1, -1, -1}, + {2, 4, 6, -1, -1, -1, -1, -1}, {0, 2, 4, 6, -1, -1, -1, -1}, + {8, -1, -1, -1, -1, -1, -1, -1}, {0, 8, -1, -1, -1, -1, -1, -1}, + {2, 8, -1, -1, -1, -1, -1, -1}, {0, 2, 8, -1, -1, -1, -1, -1}, + {4, 8, -1, -1, -1, -1, -1, -1}, {0, 4, 8, -1, -1, -1, -1, -1}, + {2, 4, 8, -1, -1, -1, -1, -1}, {0, 2, 4, 8, -1, -1, -1, -1}, + {6, 8, -1, -1, -1, -1, -1, -1}, {0, 6, 8, -1, -1, -1, -1, -1}, + {2, 6, 8, -1, -1, -1, -1, -1}, {0, 2, 6, 8, -1, -1, -1, -1}, + {4, 6, 8, -1, -1, -1, -1, -1}, {0, 4, 6, 8, -1, -1, -1, -1}, + {2, 4, 6, 8, -1, -1, -1, -1}, {0, 2, 4, 6, 8, -1, -1, -1}, + {10, -1, -1, -1, -1, -1, -1, -1}, {0, 10, -1, -1, -1, -1, -1, -1}, + {2, 10, -1, -1, -1, -1, -1, -1}, {0, 2, 10, -1, -1, -1, -1, -1}, + {4, 10, -1, -1, -1, -1, -1, -1}, {0, 4, 10, -1, -1, -1, -1, -1}, + {2, 4, 10, -1, -1, -1, -1, -1}, {0, 2, 4, 10, -1, -1, -1, -1}, + {6, 10, -1, -1, -1, -1, -1, -1}, {0, 6, 10, -1, -1, -1, -1, -1}, + {2, 6, 10, -1, -1, -1, -1, -1}, {0, 2, 6, 10, -1, -1, -1, -1}, + {4, 6, 10, -1, -1, -1, -1, -1}, {0, 4, 6, 10, -1, -1, -1, -1}, + {2, 4, 6, 10, -1, -1, -1, -1}, {0, 2, 4, 6, 10, -1, -1, -1}, + {8, 10, -1, -1, -1, -1, -1, -1}, {0, 8, 10, -1, -1, -1, -1, -1}, + {2, 8, 10, -1, -1, -1, -1, -1}, {0, 2, 8, 10, -1, -1, -1, -1}, + {4, 8, 
10, -1, -1, -1, -1, -1}, {0, 4, 8, 10, -1, -1, -1, -1}, + {2, 4, 8, 10, -1, -1, -1, -1}, {0, 2, 4, 8, 10, -1, -1, -1}, + {6, 8, 10, -1, -1, -1, -1, -1}, {0, 6, 8, 10, -1, -1, -1, -1}, + {2, 6, 8, 10, -1, -1, -1, -1}, {0, 2, 6, 8, 10, -1, -1, -1}, + {4, 6, 8, 10, -1, -1, -1, -1}, {0, 4, 6, 8, 10, -1, -1, -1}, + {2, 4, 6, 8, 10, -1, -1, -1}, {0, 2, 4, 6, 8, 10, -1, -1}, + {12, -1, -1, -1, -1, -1, -1, -1}, {0, 12, -1, -1, -1, -1, -1, -1}, + {2, 12, -1, -1, -1, -1, -1, -1}, {0, 2, 12, -1, -1, -1, -1, -1}, + {4, 12, -1, -1, -1, -1, -1, -1}, {0, 4, 12, -1, -1, -1, -1, -1}, + {2, 4, 12, -1, -1, -1, -1, -1}, {0, 2, 4, 12, -1, -1, -1, -1}, + {6, 12, -1, -1, -1, -1, -1, -1}, {0, 6, 12, -1, -1, -1, -1, -1}, + {2, 6, 12, -1, -1, -1, -1, -1}, {0, 2, 6, 12, -1, -1, -1, -1}, + {4, 6, 12, -1, -1, -1, -1, -1}, {0, 4, 6, 12, -1, -1, -1, -1}, + {2, 4, 6, 12, -1, -1, -1, -1}, {0, 2, 4, 6, 12, -1, -1, -1}, + {8, 12, -1, -1, -1, -1, -1, -1}, {0, 8, 12, -1, -1, -1, -1, -1}, + {2, 8, 12, -1, -1, -1, -1, -1}, {0, 2, 8, 12, -1, -1, -1, -1}, + {4, 8, 12, -1, -1, -1, -1, -1}, {0, 4, 8, 12, -1, -1, -1, -1}, + {2, 4, 8, 12, -1, -1, -1, -1}, {0, 2, 4, 8, 12, -1, -1, -1}, + {6, 8, 12, -1, -1, -1, -1, -1}, {0, 6, 8, 12, -1, -1, -1, -1}, + {2, 6, 8, 12, -1, -1, -1, -1}, {0, 2, 6, 8, 12, -1, -1, -1}, + {4, 6, 8, 12, -1, -1, -1, -1}, {0, 4, 6, 8, 12, -1, -1, -1}, + {2, 4, 6, 8, 12, -1, -1, -1}, {0, 2, 4, 6, 8, 12, -1, -1}, + {10, 12, -1, -1, -1, -1, -1, -1}, {0, 10, 12, -1, -1, -1, -1, -1}, + {2, 10, 12, -1, -1, -1, -1, -1}, {0, 2, 10, 12, -1, -1, -1, -1}, + {4, 10, 12, -1, -1, -1, -1, -1}, {0, 4, 10, 12, -1, -1, -1, -1}, + {2, 4, 10, 12, -1, -1, -1, -1}, {0, 2, 4, 10, 12, -1, -1, -1}, + {6, 10, 12, -1, -1, -1, -1, -1}, {0, 6, 10, 12, -1, -1, -1, -1}, + {2, 6, 10, 12, -1, -1, -1, -1}, {0, 2, 6, 10, 12, -1, -1, -1}, + {4, 6, 10, 12, -1, -1, -1, -1}, {0, 4, 6, 10, 12, -1, -1, -1}, + {2, 4, 6, 10, 12, -1, -1, -1}, {0, 2, 4, 6, 10, 12, -1, -1}, + {8, 10, 12, -1, -1, -1, -1, -1}, {0, 8, 10, 12, -1, -1, 
-1, -1}, + {2, 8, 10, 12, -1, -1, -1, -1}, {0, 2, 8, 10, 12, -1, -1, -1}, + {4, 8, 10, 12, -1, -1, -1, -1}, {0, 4, 8, 10, 12, -1, -1, -1}, + {2, 4, 8, 10, 12, -1, -1, -1}, {0, 2, 4, 8, 10, 12, -1, -1}, + {6, 8, 10, 12, -1, -1, -1, -1}, {0, 6, 8, 10, 12, -1, -1, -1}, + {2, 6, 8, 10, 12, -1, -1, -1}, {0, 2, 6, 8, 10, 12, -1, -1}, + {4, 6, 8, 10, 12, -1, -1, -1}, {0, 4, 6, 8, 10, 12, -1, -1}, + {2, 4, 6, 8, 10, 12, -1, -1}, {0, 2, 4, 6, 8, 10, 12, -1}, + {14, -1, -1, -1, -1, -1, -1, -1}, {0, 14, -1, -1, -1, -1, -1, -1}, + {2, 14, -1, -1, -1, -1, -1, -1}, {0, 2, 14, -1, -1, -1, -1, -1}, + {4, 14, -1, -1, -1, -1, -1, -1}, {0, 4, 14, -1, -1, -1, -1, -1}, + {2, 4, 14, -1, -1, -1, -1, -1}, {0, 2, 4, 14, -1, -1, -1, -1}, + {6, 14, -1, -1, -1, -1, -1, -1}, {0, 6, 14, -1, -1, -1, -1, -1}, + {2, 6, 14, -1, -1, -1, -1, -1}, {0, 2, 6, 14, -1, -1, -1, -1}, + {4, 6, 14, -1, -1, -1, -1, -1}, {0, 4, 6, 14, -1, -1, -1, -1}, + {2, 4, 6, 14, -1, -1, -1, -1}, {0, 2, 4, 6, 14, -1, -1, -1}, + {8, 14, -1, -1, -1, -1, -1, -1}, {0, 8, 14, -1, -1, -1, -1, -1}, + {2, 8, 14, -1, -1, -1, -1, -1}, {0, 2, 8, 14, -1, -1, -1, -1}, + {4, 8, 14, -1, -1, -1, -1, -1}, {0, 4, 8, 14, -1, -1, -1, -1}, + {2, 4, 8, 14, -1, -1, -1, -1}, {0, 2, 4, 8, 14, -1, -1, -1}, + {6, 8, 14, -1, -1, -1, -1, -1}, {0, 6, 8, 14, -1, -1, -1, -1}, + {2, 6, 8, 14, -1, -1, -1, -1}, {0, 2, 6, 8, 14, -1, -1, -1}, + {4, 6, 8, 14, -1, -1, -1, -1}, {0, 4, 6, 8, 14, -1, -1, -1}, + {2, 4, 6, 8, 14, -1, -1, -1}, {0, 2, 4, 6, 8, 14, -1, -1}, + {10, 14, -1, -1, -1, -1, -1, -1}, {0, 10, 14, -1, -1, -1, -1, -1}, + {2, 10, 14, -1, -1, -1, -1, -1}, {0, 2, 10, 14, -1, -1, -1, -1}, + {4, 10, 14, -1, -1, -1, -1, -1}, {0, 4, 10, 14, -1, -1, -1, -1}, + {2, 4, 10, 14, -1, -1, -1, -1}, {0, 2, 4, 10, 14, -1, -1, -1}, + {6, 10, 14, -1, -1, -1, -1, -1}, {0, 6, 10, 14, -1, -1, -1, -1}, + {2, 6, 10, 14, -1, -1, -1, -1}, {0, 2, 6, 10, 14, -1, -1, -1}, + {4, 6, 10, 14, -1, -1, -1, -1}, {0, 4, 6, 10, 14, -1, -1, -1}, + {2, 4, 6, 10, 14, -1, -1, -1}, {0, 2, 
4, 6, 10, 14, -1, -1}, + {8, 10, 14, -1, -1, -1, -1, -1}, {0, 8, 10, 14, -1, -1, -1, -1}, + {2, 8, 10, 14, -1, -1, -1, -1}, {0, 2, 8, 10, 14, -1, -1, -1}, + {4, 8, 10, 14, -1, -1, -1, -1}, {0, 4, 8, 10, 14, -1, -1, -1}, + {2, 4, 8, 10, 14, -1, -1, -1}, {0, 2, 4, 8, 10, 14, -1, -1}, + {6, 8, 10, 14, -1, -1, -1, -1}, {0, 6, 8, 10, 14, -1, -1, -1}, + {2, 6, 8, 10, 14, -1, -1, -1}, {0, 2, 6, 8, 10, 14, -1, -1}, + {4, 6, 8, 10, 14, -1, -1, -1}, {0, 4, 6, 8, 10, 14, -1, -1}, + {2, 4, 6, 8, 10, 14, -1, -1}, {0, 2, 4, 6, 8, 10, 14, -1}, + {12, 14, -1, -1, -1, -1, -1, -1}, {0, 12, 14, -1, -1, -1, -1, -1}, + {2, 12, 14, -1, -1, -1, -1, -1}, {0, 2, 12, 14, -1, -1, -1, -1}, + {4, 12, 14, -1, -1, -1, -1, -1}, {0, 4, 12, 14, -1, -1, -1, -1}, + {2, 4, 12, 14, -1, -1, -1, -1}, {0, 2, 4, 12, 14, -1, -1, -1}, + {6, 12, 14, -1, -1, -1, -1, -1}, {0, 6, 12, 14, -1, -1, -1, -1}, + {2, 6, 12, 14, -1, -1, -1, -1}, {0, 2, 6, 12, 14, -1, -1, -1}, + {4, 6, 12, 14, -1, -1, -1, -1}, {0, 4, 6, 12, 14, -1, -1, -1}, + {2, 4, 6, 12, 14, -1, -1, -1}, {0, 2, 4, 6, 12, 14, -1, -1}, + {8, 12, 14, -1, -1, -1, -1, -1}, {0, 8, 12, 14, -1, -1, -1, -1}, + {2, 8, 12, 14, -1, -1, -1, -1}, {0, 2, 8, 12, 14, -1, -1, -1}, + {4, 8, 12, 14, -1, -1, -1, -1}, {0, 4, 8, 12, 14, -1, -1, -1}, + {2, 4, 8, 12, 14, -1, -1, -1}, {0, 2, 4, 8, 12, 14, -1, -1}, + {6, 8, 12, 14, -1, -1, -1, -1}, {0, 6, 8, 12, 14, -1, -1, -1}, + {2, 6, 8, 12, 14, -1, -1, -1}, {0, 2, 6, 8, 12, 14, -1, -1}, + {4, 6, 8, 12, 14, -1, -1, -1}, {0, 4, 6, 8, 12, 14, -1, -1}, + {2, 4, 6, 8, 12, 14, -1, -1}, {0, 2, 4, 6, 8, 12, 14, -1}, + {10, 12, 14, -1, -1, -1, -1, -1}, {0, 10, 12, 14, -1, -1, -1, -1}, + {2, 10, 12, 14, -1, -1, -1, -1}, {0, 2, 10, 12, 14, -1, -1, -1}, + {4, 10, 12, 14, -1, -1, -1, -1}, {0, 4, 10, 12, 14, -1, -1, -1}, + {2, 4, 10, 12, 14, -1, -1, -1}, {0, 2, 4, 10, 12, 14, -1, -1}, + {6, 10, 12, 14, -1, -1, -1, -1}, {0, 6, 10, 12, 14, -1, -1, -1}, + {2, 6, 10, 12, 14, -1, -1, -1}, {0, 2, 6, 10, 12, 14, -1, -1}, + {4, 6, 10, 12, 14, -1, 
-1, -1}, {0, 4, 6, 10, 12, 14, -1, -1}, + {2, 4, 6, 10, 12, 14, -1, -1}, {0, 2, 4, 6, 10, 12, 14, -1}, + {8, 10, 12, 14, -1, -1, -1, -1}, {0, 8, 10, 12, 14, -1, -1, -1}, + {2, 8, 10, 12, 14, -1, -1, -1}, {0, 2, 8, 10, 12, 14, -1, -1}, + {4, 8, 10, 12, 14, -1, -1, -1}, {0, 4, 8, 10, 12, 14, -1, -1}, + {2, 4, 8, 10, 12, 14, -1, -1}, {0, 2, 4, 8, 10, 12, 14, -1}, + {6, 8, 10, 12, 14, -1, -1, -1}, {0, 6, 8, 10, 12, 14, -1, -1}, + {2, 6, 8, 10, 12, 14, -1, -1}, {0, 2, 6, 8, 10, 12, 14, -1}, + {4, 6, 8, 10, 12, 14, -1, -1}, {0, 4, 6, 8, 10, 12, 14, -1}, + {2, 4, 6, 8, 10, 12, 14, -1}, {0, 2, 4, 6, 8, 10, 12, 14}, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_avx2_rej_uniform_table \ + MLKEM_NAMESPACE(empty_cu_avx2_rej_uniform_table) +int empty_cu_avx2_rej_uniform_table; +#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.S b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.S similarity index 81% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.S rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.S index 18325ebec0..5e708748a8 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.S @@ -1,9 +1,21 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + #include "consts.h" -.include "fq.inc" -.include "shuffle.inc" +#include "fq.inc" +#include "shuffle.inc" -/* -nttpack_avx: +.global MLKEM_ASM_NAMESPACE(nttpack_avx2) +MLKEM_ASM_NAMESPACE(nttpack_avx2): #load vmovdqa (%rdi),%ymm4 vmovdqa 32(%rdi),%ymm5 @@ -45,10 +57,8 @@ vmovdqa %ymm5,192(%rdi) vmovdqa %ymm11,224(%rdi) ret -*/ 
-.text -nttunpack128_avx: +nttunpack128_avx2: #load vmovdqa (%rdi),%ymm4 vmovdqa 32(%rdi),%ymm5 @@ -91,11 +101,11 @@ vmovdqa %ymm11,224(%rdi) ret -.global cdecl(nttunpack_avx) -cdecl(nttunpack_avx): -call nttunpack128_avx +.global MLKEM_ASM_NAMESPACE(nttunpack_avx2) +MLKEM_ASM_NAMESPACE(nttunpack_avx2): +call nttunpack128_avx2 add $256,%rdi -call nttunpack128_avx +call nttunpack128_avx2 ret ntttobytes128_avx: @@ -109,16 +119,6 @@ vmovdqa 160(%rsi),%ymm10 vmovdqa 192(%rsi),%ymm11 vmovdqa 224(%rsi),%ymm12 -#csubq -csubq 5,13 -csubq 6,13 -csubq 7,13 -csubq 8,13 -csubq 9,13 -csubq 10,13 -csubq 11,13 -csubq 12,13 - #bitpack vpsllw $12,%ymm6,%ymm4 vpor %ymm4,%ymm5,%ymm4 @@ -168,10 +168,10 @@ vmovdqu %ymm9,160(%rdi) ret -.global cdecl(ntttobytes_avx) -cdecl(ntttobytes_avx): +.global MLKEM_ASM_NAMESPACE(ntttobytes_avx2) +MLKEM_ASM_NAMESPACE(ntttobytes_avx2): #consts -vmovdqa _16XQ*2(%rdx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rdx),%ymm0 call ntttobytes128_avx add $256,%rsi add $192,%rdi @@ -244,12 +244,14 @@ vmovdqa %ymm1,224(%rdi) ret -.global cdecl(nttfrombytes_avx) -cdecl(nttfrombytes_avx): +.global MLKEM_ASM_NAMESPACE(nttfrombytes_avx2) +MLKEM_ASM_NAMESPACE(nttfrombytes_avx2): #consts -vmovdqa _16XMASK*2(%rdx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMASK*2(%rdx),%ymm0 call nttfrombytes128_avx add $256,%rdi add $192,%rsi call nttfrombytes128_avx ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.inc b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.inc similarity index 55% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.inc rename to src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.inc index 73e9ffe03c..359807bd25 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.inc +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/shuffle.inc @@ -1,3 +1,8 @@ +/* + * Copyright (c) 
2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + .macro shuffle8 r0,r1,r2,r3 vperm2i128 $0x20,%ymm\r1,%ymm\r0,%ymm\r2 vperm2i128 $0x31,%ymm\r1,%ymm\r0,%ymm\r3 @@ -8,12 +13,19 @@ vpunpcklqdq %ymm\r1,%ymm\r0,%ymm\r2 vpunpckhqdq %ymm\r1,%ymm\r0,%ymm\r3 .endm +/* Shuffle r0=(a0,b0,c0,d0,...), r1=(a1,b1,c1,d1,...) into */ +/* r2 = (a0,b0,a1,b1,e0,f0,e1,f1,...) */ +/* r3 = (c0,d0,c1,d1,g0,h0,g1,h1,...) */ .macro shuffle2 r0,r1,r2,r3 -#vpsllq $32,%ymm\r1,%ymm\r2 +/* r2=(a1,b1,a1,b1,e1,f1,e1,f1,...) */ vmovsldup %ymm\r1,%ymm\r2 +/* Conditional move */ +/* 0xAA = 0b10101010 */ +/* r2=(a0,b0,a1,b1,e0,f0,e1,f1,...) */ vpblendd $0xAA,%ymm\r2,%ymm\r0,%ymm\r2 +/* r0=(c0,d0,0,0,g0,h0,0,0,...) */ vpsrlq $32,%ymm\r0,%ymm\r0 -#vmovshdup %ymm\r0,%ymm\r0 +/* r3=(c0,d0,c1,d1,g0,h0,g1,h1,...) */ vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/x86_64_zetas.i b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/x86_64_zetas.i new file mode 100644 index 0000000000..26d582ee53 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/x86_64/src/x86_64_zetas.i @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +/* + * Table of zeta values used in the AVX2 NTTs + * See autogen for details. 
+ */ + +31498, 31498, 31498, 31498, -758, -758, -758, -758, 0, 0, 0, 0, 0, 0, 0, 0, + 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, + 14745, 14745, 14745, 14745, 14745, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, -359, 13525, 13525, 13525, + 13525, 13525, 13525, 13525, 13525, -12402, -12402, -12402, -12402, -12402, + -12402, -12402, -12402, 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, -20907, -20907, -20907, + -20907, 27758, 27758, 27758, 27758, -3799, -3799, -3799, -3799, -15690, + -15690, -15690, -15690, -171, -171, -171, -171, 622, 622, 622, 622, 1577, + 1577, 1577, 1577, 182, 182, 182, 182, -5827, -5827, 17363, 17363, -26360, + -26360, -29057, -29057, 5571, 5571, -1102, -1102, 21438, 21438, -26242, + -26242, 573, 573, -1325, -1325, 264, 264, 383, 383, -829, -829, 1458, 1458, + -1602, -1602, -130, -130, -5689, -6516, 1496, 30967, -23565, 20179, 20710, + 25080, -12796, 26616, 16064, -12442, 9134, -650, -25986, 27837, 1223, 652, + -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, -1618, -1162, + 126, 1469, -335, -11477, -32227, 20494, -27738, 945, -14883, 6182, 32010, + 10631, 29175, -28762, -18486, 17560, -14430, -5276, -1103, 555, -1251, 1550, + 422, 177, -291, 1574, -246, 1159, -777, -602, -1590, -872, 418, -156, 11182, + 13387, -14233, -21655, 13131, -4587, 23092, 5493, -32502, 30317, -18741, + 12639, 20100, 18525, 19529, -12619, 430, 843, 871, 105, 587, -235, -460, + 1653, 778, -147, 1483, 1119, 644, 349, 329, -75, 787, 787, 787, 787, 787, + 787, 787, 787, 787, 787, 787, 787, 787, 787, 787, 787, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, 28191, 28191, 28191, 28191, 28191, 28191, 28191, 28191, + -16694, -16694, -16694, -16694, -16694, -16694, -16694, -16694, 287, 287, + 287, 287, 287, 287, 287, 287, 202, 202, 202, 202, 202, 202, 202, 202, 10690, + 
10690, 10690, 10690, 1358, 1358, 1358, 1358, -11202, -11202, -11202, -11202, + 31164, 31164, 31164, 31164, 962, 962, 962, 962, -1202, -1202, -1202, -1202, + -1474, -1474, -1474, -1474, 1468, 1468, 1468, 1468, -28073, -28073, 24313, + 24313, -10532, -10532, 8800, 8800, 18426, 18426, 8859, 8859, 26675, 26675, + -16163, -16163, -681, -681, 1017, 1017, 732, 732, 608, 608, -1542, -1542, + 411, 411, -205, -205, -1571, -1571, 19883, -28250, -15887, -8898, -28309, + 9075, -30199, 18249, 13426, 14017, -29156, -12757, 16832, 4311, -24155, + -17915, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961, -1508, + -725, 448, -1065, 677, -1275, -31183, 25435, -7382, 24391, -20927, 10946, + 24214, 16989, 10335, -7934, -22502, 10906, 31636, 28644, 23998, -17422, 817, + 603, 1322, -1465, -1215, 1218, -874, -1187, -1185, -1278, -1510, -870, -108, + 996, 958, 1522, 20297, 2146, 15355, -32384, -6280, -14903, -11044, 14469, + -21498, -20198, 23210, -17442, -23860, -20257, 7756, 23132, 1097, 610, + -1285, 384, -136, -1335, 220, -1659, -1530, 794, -854, 478, -308, 991, + -1460, 1628, diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-1024_x86_64/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. 
+ */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/LICENSE similarity index 100% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/LICENSE rename to src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/LICENSE diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/README.md b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/README.md new file mode 100644 index 0000000000..e499a4a229 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/README.md @@ -0,0 +1,19 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +# AArch64 backend (little endian) + +This directory contains a native backend for little endian AArch64 systems. It is derived from the following research +works: + +- _Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1_, Hanno Becker, Vincent Hwang, Matthias + J. Kannwischer, Bo-Yin Yang, and Shang-Yi Yang, [https://eprint.iacr.org/2021/986](https://eprint.iacr.org/2021/986) +- _Fast and Clean: Auditable high-performance assembly via constraint solving_, Amin Abdulrahman, Hanno Becker, Matthias + J. 
Kannwischer, Fabien Klein, [https://eprint.iacr.org/2022/1303](https://eprint.iacr.org/2022/1303) + +## Profiles + +This backend comes with two profiles: "clean" and optimized. The "clean" backend is handwritten and meant to be easy to +read and modify; for example, it heavily leverages register aliases and assembly macros. The optimized profile is +automatically generated from the clean profile via [SLOTHY](https://github.com/slothy-optimizer/slothy). Currently, the +target architecture is Cortex-A55, but you can easily re-optimize the code for a different microarchitecture supported +by SLOTHY, by adjusting the parameters in [optimize.sh](src/optimize.sh). diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/clean.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/clean.h new file mode 100644 index 0000000000..43a401dfc4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/clean.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME AARCH64_CLEAN + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well.
*/ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "aarch64/src/clean_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/opt.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/opt.h new file mode 100644 index 0000000000..04323c3e79 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/opt.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for optimized assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME AARCH64_OPT + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. */ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "aarch64/src/opt_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/aarch64_zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/aarch64_zetas.c new file mode 100644 index 0000000000..1e189fd995 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/aarch64_zetas.c @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly.
+ */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +#include +#include "arith_native_aarch64.h" + +/* + * Table of zeta values used in the AArch64 forward NTT + * See autogen for details. + */ +ALIGN const int16_t aarch64_ntt_zetas_layer01234[] = { + -1600, -15749, -749, -7373, -40, -394, -687, -6762, 630, 6201, + -1432, -14095, 848, 8347, 0, 0, 1062, 10453, 296, 2914, + -882, -8682, 0, 0, -1410, -13879, 1339, 13180, 1476, 14529, + 0, 0, 193, 1900, -283, -2786, 56, 551, 0, 0, + 797, 7845, -1089, -10719, 1333, 13121, 0, 0, -543, -5345, + 1426, 14036, -1235, -12156, 0, 0, -69, -679, 535, 5266, + -447, -4400, 0, 0, 569, 5601, -936, -9213, -450, -4429, + 0, 0, -1583, -15582, -1355, -13338, 821, 8081, 0, 0, +}; + +ALIGN const int16_t aarch64_ntt_zetas_layer56[] = { + 289, 289, 331, 331, -76, -76, -1573, -1573, 2845, + 2845, 3258, 3258, -748, -748, -15483, -15483, 17, 17, + 583, 583, 1637, 1637, -1041, -1041, 167, 167, 5739, + 5739, 16113, 16113, -10247, -10247, -568, -568, -680, -680, + 723, 723, 1100, 1100, -5591, -5591, -6693, -6693, 7117, + 7117, 10828, 10828, 1197, 1197, -1025, -1025, -1052, -1052, + -1274, -1274, 11782, 11782, -10089, -10089, -10355, -10355, -12540, + -12540, 1409, 1409, -48, -48, 756, 756, -314, -314, + 13869, 13869, -472, -472, 7441, 7441, -3091, -3091, -667, + -667, 233, 233, -1173, -1173, -279, -279, -6565, -6565, + 2293, 2293, -11546, -11546, -2746, -2746, 650, 650, -1352, + -1352, -816, -816, 632, 632, 6398, 6398, -13308, -13308, + -8032, -8032, 6221, 6221, -1626, -1626, -540, -540, -1482, + -1482, 1461, 1461, -16005, -16005, -5315, -5315, -14588, -14588, + 14381, 14381, 1651, 1651, -1540, -1540, 952, 952, -642, + -642, 16251, 16251, -15159, -15159, 9371, 9371, -6319, -6319, + -464, -464, 33, 33, 1320, 1320, -1414, -1414, -4567, + -4567, 325, 325, 12993, 12993, -13918, -13918, 939, 939, + -892, -892, 733, 733, 268, 268, 9243, 9243, -8780, + -8780, 7215, 
7215, 2638, 2638, -1021, -1021, -941, -941, + -992, -992, 641, 641, -10050, -10050, -9262, -9262, -9764, + -9764, 6309, 6309, -1010, -1010, 1435, 1435, 807, 807, + 452, 452, -9942, -9942, 14125, 14125, 7943, 7943, 4449, + 4449, 1584, 1584, -1292, -1292, 375, 375, -1239, -1239, + 15592, 15592, -12717, -12717, 3691, 3691, -12196, -12196, -1031, + -1031, -109, -109, -780, -780, 1645, 1645, -10148, -10148, + -1073, -1073, -7678, -7678, 16192, 16192, 1438, 1438, -461, + -461, 1534, 1534, -927, -927, 14155, 14155, -4538, -4538, + 15099, 15099, -9125, -9125, 1063, 1063, -556, -556, -1230, + -1230, -863, -863, 10463, 10463, -5473, -5473, -12107, -12107, + -8495, -8495, 319, 319, 757, 757, 561, 561, -735, + -735, 3140, 3140, 7451, 7451, 5522, 5522, -7235, -7235, + -682, -682, -712, -712, 1481, 1481, 648, 648, -6713, + -6713, -7008, -7008, 14578, 14578, 6378, 6378, -525, -525, + 403, 403, 1143, 1143, -554, -554, -5168, -5168, 3967, + 3967, 11251, 11251, -5453, -5453, 1092, 1092, 1026, 1026, + -1179, -1179, 886, 886, 10749, 10749, 10099, 10099, -11605, + -11605, 8721, 8721, -855, -855, -219, -219, 1227, 1227, + 910, 910, -8416, -8416, -2156, -2156, 12078, 12078, 8957, + 8957, -1607, -1607, -1455, -1455, -1219, -1219, 885, 885, + -15818, -15818, -14322, -14322, -11999, -11999, 8711, 8711, 1212, + 1212, 1029, 1029, -394, -394, -1175, -1175, 11930, 11930, + 10129, 10129, -3878, -3878, -11566, -11566, +}; + +ALIGN const int16_t aarch64_invntt_zetas_layer01234[] = { + 1583, 15582, -821, -8081, 1355, 13338, 0, 0, -569, -5601, + 450, 4429, 936, 9213, 0, 0, 69, 679, 447, 4400, + -535, -5266, 0, 0, 543, 5345, 1235, 12156, -1426, -14036, + 0, 0, -797, -7845, -1333, -13121, 1089, 10719, 0, 0, + -193, -1900, -56, -551, 283, 2786, 0, 0, 1410, 13879, + -1476, -14529, -1339, -13180, 0, 0, -1062, -10453, 882, 8682, + -296, -2914, 0, 0, 1600, 15749, 40, 394, 749, 7373, + -848, -8347, 1432, 14095, -630, -6201, 687, 6762, 0, 0, +}; + +ALIGN const int16_t aarch64_invntt_zetas_layer56[] = { + 
-910, -910, -1227, -1227, 219, 219, 855, 855, -8957, + -8957, -12078, -12078, 2156, 2156, 8416, 8416, 1175, 1175, + 394, 394, -1029, -1029, -1212, -1212, 11566, 11566, 3878, + 3878, -10129, -10129, -11930, -11930, -885, -885, 1219, 1219, + 1455, 1455, 1607, 1607, -8711, -8711, 11999, 11999, 14322, + 14322, 15818, 15818, -648, -648, -1481, -1481, 712, 712, + 682, 682, -6378, -6378, -14578, -14578, 7008, 7008, 6713, + 6713, -886, -886, 1179, 1179, -1026, -1026, -1092, -1092, + -8721, -8721, 11605, 11605, -10099, -10099, -10749, -10749, 554, + 554, -1143, -1143, -403, -403, 525, 525, 5453, 5453, + -11251, -11251, -3967, -3967, 5168, 5168, 927, 927, -1534, + -1534, 461, 461, -1438, -1438, 9125, 9125, -15099, -15099, + 4538, 4538, -14155, -14155, 735, 735, -561, -561, -757, + -757, -319, -319, 7235, 7235, -5522, -5522, -7451, -7451, + -3140, -3140, 863, 863, 1230, 1230, 556, 556, -1063, + -1063, 8495, 8495, 12107, 12107, 5473, 5473, -10463, -10463, + -452, -452, -807, -807, -1435, -1435, 1010, 1010, -4449, + -4449, -7943, -7943, -14125, -14125, 9942, 9942, -1645, -1645, + 780, 780, 109, 109, 1031, 1031, -16192, -16192, 7678, + 7678, 1073, 1073, 10148, 10148, 1239, 1239, -375, -375, + 1292, 1292, -1584, -1584, 12196, 12196, -3691, -3691, 12717, + 12717, -15592, -15592, 1414, 1414, -1320, -1320, -33, -33, + 464, 464, 13918, 13918, -12993, -12993, -325, -325, 4567, + 4567, -641, -641, 992, 992, 941, 941, 1021, 1021, + -6309, -6309, 9764, 9764, 9262, 9262, 10050, 10050, -268, + -268, -733, -733, 892, 892, -939, -939, -2638, -2638, + -7215, -7215, 8780, 8780, -9243, -9243, -632, -632, 816, + 816, 1352, 1352, -650, -650, -6221, -6221, 8032, 8032, + 13308, 13308, -6398, -6398, 642, 642, -952, -952, 1540, + 1540, -1651, -1651, 6319, 6319, -9371, -9371, 15159, 15159, + -16251, -16251, -1461, -1461, 1482, 1482, 540, 540, 1626, + 1626, -14381, -14381, 14588, 14588, 5315, 5315, 16005, 16005, + 1274, 1274, 1052, 1052, 1025, 1025, -1197, -1197, 12540, + 12540, 10355, 10355, 10089, 
10089, -11782, -11782, 279, 279, + 1173, 1173, -233, -233, 667, 667, 2746, 2746, 11546, + 11546, -2293, -2293, 6565, 6565, 314, 314, -756, -756, + 48, 48, -1409, -1409, 3091, 3091, -7441, -7441, 472, + 472, -13869, -13869, 1573, 1573, 76, 76, -331, -331, + -289, -289, 15483, 15483, 748, 748, -3258, -3258, -2845, + -2845, -1100, -1100, -723, -723, 680, 680, 568, 568, + -10828, -10828, -7117, -7117, 6693, 6693, 5591, 5591, 1041, + 1041, -1637, -1637, -583, -583, -17, -17, 10247, 10247, + -16113, -16113, -5739, -5739, -167, -167, +}; + +ALIGN const int16_t aarch64_zetas_mulcache_native[] = { + 17, -17, -568, 568, 583, -583, -680, 680, 1637, -1637, 723, + -723, -1041, 1041, 1100, -1100, 1409, -1409, -667, 667, -48, 48, + 233, -233, 756, -756, -1173, 1173, -314, 314, -279, 279, -1626, + 1626, 1651, -1651, -540, 540, -1540, 1540, -1482, 1482, 952, -952, + 1461, -1461, -642, 642, 939, -939, -1021, 1021, -892, 892, -941, + 941, 733, -733, -992, 992, 268, -268, 641, -641, 1584, -1584, + -1031, 1031, -1292, 1292, -109, 109, 375, -375, -780, 780, -1239, + 1239, 1645, -1645, 1063, -1063, 319, -319, -556, 556, 757, -757, + -1230, 1230, 561, -561, -863, 863, -735, 735, -525, 525, 1092, + -1092, 403, -403, 1026, -1026, 1143, -1143, -1179, 1179, -554, 554, + 886, -886, -1607, 1607, 1212, -1212, -1455, 1455, 1029, -1029, -1219, + 1219, -394, 394, 885, -885, -1175, 1175, +}; + +ALIGN const int16_t aarch64_zetas_mulcache_twisted_native[] = { + 167, -167, -5591, 5591, 5739, -5739, -6693, 6693, 16113, + -16113, 7117, -7117, -10247, 10247, 10828, -10828, 13869, -13869, + -6565, 6565, -472, 472, 2293, -2293, 7441, -7441, -11546, + 11546, -3091, 3091, -2746, 2746, -16005, 16005, 16251, -16251, + -5315, 5315, -15159, 15159, -14588, 14588, 9371, -9371, 14381, + -14381, -6319, 6319, 9243, -9243, -10050, 10050, -8780, 8780, + -9262, 9262, 7215, -7215, -9764, 9764, 2638, -2638, 6309, + -6309, 15592, -15592, -10148, 10148, -12717, 12717, -1073, 1073, + 3691, -3691, -7678, 7678, -12196, 12196, 
16192, -16192, 10463, + -10463, 3140, -3140, -5473, 5473, 7451, -7451, -12107, 12107, + 5522, -5522, -8495, 8495, -7235, 7235, -5168, 5168, 10749, + -10749, 3967, -3967, 10099, -10099, 11251, -11251, -11605, 11605, + -5453, 5453, 8721, -8721, -15818, 15818, 11930, -11930, -14322, + 14322, 10129, -10129, -11999, 11999, -3878, 3878, 8711, -8711, + -11566, 11566, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_aarch64_zetas MLKEM_NAMESPACE(empty_cu_aarch64_zetas) +int empty_cu_aarch64_zetas; +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/arith_native_aarch64.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/arith_native_aarch64.h new file mode 100644 index 0000000000..6a5ee8a7d6 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/arith_native_aarch64.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_AARCH64_NATIVE_H +#define MLKEM_AARCH64_NATIVE_H + +#include +#include "common.h" + +#define aarch64_ntt_zetas_layer01234 \ + MLKEM_NAMESPACE(aarch64_ntt_zetas_layer01234) +#define aarch64_ntt_zetas_layer56 MLKEM_NAMESPACE(aarch64_ntt_zetas_layer56) +#define aarch64_invntt_zetas_layer01234 \ + MLKEM_NAMESPACE(aarch64_invntt_zetas_layer01234) +#define aarch64_invntt_zetas_layer56 \ + MLKEM_NAMESPACE(aarch64_invntt_zetas_layer56) +#define aarch64_zetas_mulcache_native \ + MLKEM_NAMESPACE(aarch64_zetas_mulcache_native) +#define aarch64_zetas_mulcache_twisted_native \ + MLKEM_NAMESPACE(aarch64_zetas_mulcache_twisted_native) +#define rej_uniform_table MLKEM_NAMESPACE(rej_uniform_table) + +extern const int16_t aarch64_ntt_zetas_layer01234[]; +extern const int16_t aarch64_ntt_zetas_layer56[]; +extern const int16_t aarch64_invntt_zetas_layer01234[]; +extern const int16_t aarch64_invntt_zetas_layer56[]; +extern const int16_t aarch64_zetas_mulcache_native[]; +extern 
const int16_t aarch64_zetas_mulcache_twisted_native[]; +extern const uint8_t rej_uniform_table[]; + +#define ntt_asm_clean MLKEM_NAMESPACE(ntt_asm_clean) +void ntt_asm_clean(int16_t *, const int16_t *, const int16_t *); + +#define ntt_asm_opt MLKEM_NAMESPACE(ntt_asm_opt) +void ntt_asm_opt(int16_t *, const int16_t *, const int16_t *); + +#define intt_asm_clean MLKEM_NAMESPACE(intt_asm_clean) +void intt_asm_clean(int16_t *, const int16_t *, const int16_t *); + +#define intt_asm_opt MLKEM_NAMESPACE(intt_asm_opt) +void intt_asm_opt(int16_t *, const int16_t *, const int16_t *); + +#define rej_uniform_asm_clean MLKEM_NAMESPACE(rej_uniform_asm_clean) +unsigned int rej_uniform_asm_clean(int16_t *r, const uint8_t *buf, + unsigned int buflen, const uint8_t *table); + +#define poly_reduce_asm_clean MLKEM_NAMESPACE(poly_reduce_asm_clean) +void poly_reduce_asm_clean(int16_t *); + +#define poly_reduce_asm_opt MLKEM_NAMESPACE(poly_reduce_asm_opt) +void poly_reduce_asm_opt(int16_t *); + +#define poly_tomont_asm_clean MLKEM_NAMESPACE(poly_tomont_asm_clean) +void poly_tomont_asm_clean(int16_t *); + +#define poly_tomont_asm_opt MLKEM_NAMESPACE(poly_tomont_asm_opt) +void poly_tomont_asm_opt(int16_t *); + +#define poly_mulcache_compute_asm_clean \ + MLKEM_NAMESPACE(poly_mulcache_compute_asm_clean) +void poly_mulcache_compute_asm_clean(int16_t *, const int16_t *, + const int16_t *, const int16_t *); + + +#define poly_mulcache_compute_asm_opt \ + MLKEM_NAMESPACE(poly_mulcache_compute_asm_opt) +void poly_mulcache_compute_asm_opt(int16_t *, const int16_t *, const int16_t *, + const int16_t *); + +#define poly_tobytes_asm_clean MLKEM_NAMESPACE(poly_tobytes_asm_clean) +void poly_tobytes_asm_clean(uint8_t *r, const int16_t *a); + +#define poly_tobytes_asm_opt MLKEM_NAMESPACE(poly_tobytes_asm_opt) +void poly_tobytes_asm_opt(uint8_t *r, const int16_t *a); + +#define polyvec_basemul_acc_montgomery_cached_asm_clean \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) +void 
polyvec_basemul_acc_montgomery_cached_asm_clean(int16_t *r, + const int16_t *a, + const int16_t *b, + const int16_t *b_cache); + +#define polyvec_basemul_acc_montgomery_cached_asm_opt \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) +void polyvec_basemul_acc_montgomery_cached_asm_opt(int16_t *r, const int16_t *a, + const int16_t *b, + const int16_t *b_cache); + +#endif /* MLKEM_AARCH64_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/clean_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/clean_impl.h new file mode 100644 index 0000000000..b0ff3d5972 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/clean_impl.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include "arith_native_aarch64.h" + +#include "poly.h" +#include "polyvec.h" + +/* Set of primitives that this backend replaces */ +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_REJ_UNIFORM + +static INLINE void ntt_native(poly *data) +{ + ntt_asm_clean(data->coeffs, aarch64_ntt_zetas_layer01234, + aarch64_ntt_zetas_layer56); +} + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +static INLINE void intt_native(poly *data) +{ + intt_asm_clean(data->coeffs, aarch64_invntt_zetas_layer01234, + aarch64_invntt_zetas_layer56); +} + +static INLINE void poly_reduce_native(poly *data) +{ + poly_reduce_asm_clean(data->coeffs); +} +static INLINE void poly_tomont_native(poly *data) +{ + poly_tomont_asm_clean(data->coeffs); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + poly_mulcache_compute_asm_clean(x->coeffs, y->coeffs, + aarch64_zetas_mulcache_native, + aarch64_zetas_mulcache_twisted_native); +} +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_asm_clean( + r->coeffs, a->vec[0].coeffs, b->vec[0].coeffs, b_cache->vec[0].coeffs); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + poly_tobytes_asm_clean(r, a->coeffs); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + if (len != MLKEM_N || buflen % 24 != 0) + { + return -1; + } + return (int)rej_uniform_asm_clean(r, buf, buflen, rej_uniform_table); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ diff --git 
a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/consts.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/consts.h new file mode 100644 index 0000000000..c40947299c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/consts.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_AARCH64_CONSTS) +#define MLKEM_NATIVE_AARCH64_CONSTS + +#include +#include "common.h" + +#define zetas_mulcache_native MLKEM_NAMESPACE(zetas_mulcache_native) +extern const int16_t zetas_mulcache_native[256]; + +#define zetas_mulcache_twisted_native \ + MLKEM_NAMESPACE(zetas_mulcache_twisted_native) +extern const int16_t zetas_mulcache_twisted_native[256]; + +#endif /* MLKEM_NATIVE_AARCH64_CONSTS */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_clean.S new file mode 100644 index 0000000000..623a82ae9c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_clean.S @@ -0,0 +1,364 @@ +/// Copyright (c) 2024 The mlkem-native project authors +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. 
+/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. +.macro mulmodq dst, src, const, idx0, idx1 + // Signed barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q.
+ sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro gs_butterfly a, b, root, idx0, idx1 + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmodq \b, tmp, \root, \idx0, \idx1 +.endm + +.macro gs_butterfly_v a, b, root, root_twisted + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmod \b, tmp, \root, \root_twisted +.endm + +.macro mul_ninv dst0, dst1, dst2, dst3, src0, src1, src2, src3 + mulmod \dst0, \src0, ninv, ninv_tw + mulmod \dst1, \src1, ninv, ninv_tw + mulmod \dst2, \src2, ninv, ninv_tw + mulmod \dst3, \src3, ninv, ninv_tw +.endm + +.macro barrett_reduce a + sqdmulh t0.8h, \a\().8h, consts.h[1] + srshr t0.8h, t0.8h, #11 + mls \a\().8h, t0.8h, consts.h[0] +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 + ldr q_root1, [r01234_ptr, #-16] +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro transpose_single data_out, data_in + trn1 \data_out\()0.4s, \data_in\()0.4s, \data_in\()1.4s + trn2 \data_out\()1.4s, \data_in\()0.4s, \data_in\()1.4s + trn1 \data_out\()2.4s, 
\data_in\()2.4s, \data_in\()3.4s + trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +// For comparability reasons, the output range for the coefficients of this +// invNTT code is supposed to match the implementation from PQClean on commit +// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients +// are NOT canonically reduced. The ordering of the coefficients is canonical, +// also matching PQClean. + +.text + + .global MLKEM_ASM_NAMESPACE(intt_asm_clean) + + in .req x0 + r01234_ptr .req x1 + r56_ptr .req x2 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + consts .req v7 + q_consts .req q7 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + ninv .req v29 + q_ninv .req q29 + ninv_tw .req v30 + q_ninv_tw .req q30 + +/* Literal pool */ +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_consts: .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 +c_ninv: dup8h 512 +c_ninv_tw: 
dup8h 5040 + +MLKEM_ASM_NAMESPACE(intt_asm_clean): + push_stack + + ldr q_consts, c_consts + ldr q_ninv, c_ninv + ldr q_ninv_tw, c_ninv_tw + + mov inp, in + mov count, #8 + +scale_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + mul_ninv data0, data1, data2, data3, data0, data1, data2, data3 + // Bounds: Absolute value < q + + str q_data0, [inp], #64 + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, scale_start + + mov inp, in + mov count, #8 + + .p2align 2 +layer3456_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + transpose4 data // manual ld4 + + load_next_roots_56 + + // Layer 7 + gs_butterfly_v data0, data1, root1, root1_tw + gs_butterfly_v data2, data3, root2, root2_tw + // Bounds: + // data0, data2: < 2q + // data1, data3: < q + + // Layer 6 + gs_butterfly_v data0, data2, root0, root0_tw + gs_butterfly_v data1, data3, root0, root0_tw + // Bounds: + // data0: < 4q + // data1: < 2q + // data2, data3: < q + + transpose4 data + + load_next_roots_34 + + // Layer 5 + gs_butterfly data0, data1, root0, 2, 3 + gs_butterfly data2, data3, root0, 4, 5 + // Max bound: 8q + + // Not all of those reductions are needed, but the bounds tracking + // is easier if we uniformly reduce at this point. 
+ barrett_reduce data0 + barrett_reduce data2 + barrett_reduce data1 + barrett_reduce data3 + + // Bounds: q/2 + + // Layer 4 + gs_butterfly data0, data2, root0, 0, 1 + gs_butterfly data1, data3, root0, 0, 1 + // Bounds: < q + + str q_data0, [inp], #(64) + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, layer3456_start + + // --------------------------------------------------------------------- + + mov count, #4 + load_roots_012 + + .p2align 2 + +layer012_start: + + ldr q_data0, [in, #0] + ldr q_data1, [in, #(1*(512/8))] + ldr q_data2, [in, #(2*(512/8))] + ldr q_data3, [in, #(3*(512/8))] + ldr q_data4, [in, #(4*(512/8))] + ldr q_data5, [in, #(5*(512/8))] + ldr q_data6, [in, #(6*(512/8))] + ldr q_data7, [in, #(7*(512/8))] + + gs_butterfly data0, data1, root0, 6, 7 + gs_butterfly data2, data3, root1, 0, 1 + gs_butterfly data4, data5, root1, 2, 3 + gs_butterfly data6, data7, root1, 4, 5 + + gs_butterfly data0, data2, root0, 2, 3 + gs_butterfly data1, data3, root0, 2, 3 + gs_butterfly data4, data6, root0, 4, 5 + gs_butterfly data5, data7, root0, 4, 5 + + gs_butterfly data0, data4, root0, 0, 1 + gs_butterfly data1, data5, root0, 0, 1 + gs_butterfly data2, data6, root0, 0, 1 + gs_butterfly data3, data7, root0, 0, 1 + + // Bounds: < 8q + + str q_data4, [in, #(4*(512/8))] + str q_data5, [in, #(5*(512/8))] + str q_data6, [in, #(6*(512/8))] + str q_data7, [in, #(7*(512/8))] + + str q_data0, [in], #(16) + str q_data1, [in, #(-16 + 1*(512/8))] + str q_data2, [in, #(-16 + 2*(512/8))] + str q_data3, [in, #(-16 + 3*(512/8))] + + subs count, count, #1 + cbnz count, layer012_start + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_opt.S new file mode 100644 index 0000000000..e332efef8f --- /dev/null +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/intt_opt.S @@ -0,0 +1,1020 @@ +/// Copyright (c) 2024 The mlkem-native project authors +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. +.macro mulmodq dst, src, const, idx0, idx1 + // Signed barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q.
+ sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro gs_butterfly a, b, root, idx0, idx1 + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmodq \b, tmp, \root, \idx0, \idx1 +.endm + +.macro gs_butterfly_v a, b, root, root_twisted + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmod \b, tmp, \root, \root_twisted +.endm + +.macro mul_ninv dst0, dst1, dst2, dst3, src0, src1, src2, src3 + mulmod \dst0, \src0, ninv, ninv_tw + mulmod \dst1, \src1, ninv, ninv_tw + mulmod \dst2, \src2, ninv, ninv_tw + mulmod \dst3, \src3, ninv, ninv_tw +.endm + +.macro barrett_reduce a + sqdmulh t0.8h, \a\().8h, consts.h[1] + srshr t0.8h, t0.8h, #11 + mls \a\().8h, t0.8h, consts.h[0] +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 + ldr q_root1, [r01234_ptr, #-16] +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro transpose_single data_out, data_in + trn1 \data_out\()0.4s, \data_in\()0.4s, \data_in\()1.4s + trn2 \data_out\()1.4s, \data_in\()0.4s, \data_in\()1.4s + trn1 \data_out\()2.4s, 
\data_in\()2.4s, \data_in\()3.4s + trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +// For comparability reasons, the output range for the coefficients of this +// invNTT code is supposed to match the implementation from PQClean on commit +// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients +// are NOT canonically reduced. The ordering of the coefficients is canonical, +// also matching PQClean. + +.text + + .global MLKEM_ASM_NAMESPACE(intt_asm_opt) + + in .req x0 + r01234_ptr .req x1 + r56_ptr .req x2 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + consts .req v7 + q_consts .req q7 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + ninv .req v29 + q_ninv .req q29 + ninv_tw .req v30 + q_ninv_tw .req q30 + +/* Literal pool */ +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_consts: .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 +c_ninv: dup8h 512 +c_ninv_tw: 
dup8h 5040 + +MLKEM_ASM_NAMESPACE(intt_asm_opt): + push_stack + + ldr q_consts, c_consts + ldr q_ninv, c_ninv + ldr q_ninv_tw, c_ninv_tw + + mov inp, in + mov count, #8 + +scale_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + mul_ninv data0, data1, data2, data3, data0, data1, data2, data3 + // Bounds: Absolute value < q + + str q_data0, [inp], #64 + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, scale_start + + mov inp, in + mov count, #8 + + .p2align 2 + // Instructions: 11 + // Expected cycles: 20 + // Expected IPC: 0.55 + // + // Cycle bound: 20.0 + // IPC bound: 0.55 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q26, [x3, #0] // *............................. + ldr q8, [x3, #16] // ..*........................... + ldr q24, [x3, #32] // ....*......................... + ldr q16, [x3, #48] // ......*....................... + ldr q9, [x2], #(6*16) // ........*..................... + trn1 v0.4S, v24.4S, v16.4S // ..........*................... + ldr q6, [x2, #-80] // ...........*.................. + ldr q3, [x2, #-64] // .............*................ + ldr q15, [x2, #-48] // ...............*.............. + ldr q4, [x2, #-32] // .................*............ + ldr q28, [x2, #-16] // ...................*.......... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q26, [x3, #0] // *.............................. + // ldr q8, [x3, #16] // ..*............................ + // ldr q24, [x3, #32] // ....*.......................... + // ldr q16, [x3, #48] // ......*........................ + // trn1 v0.4S, v24.4S, v16.4S // ..........*.................... + // ldr q9, [x2], #(6*16) // ........*...................... 
+ // ldr q6, [x2, #-80] // ...........*................... + // ldr q3, [x2, #-64] // .............*................. + // ldr q15, [x2, #-48] // ...............*............... + // ldr q4, [x2, #-32] // .................*............. + // ldr q28, [x2, #-16] // ...................*........... + + sub count, count, #1 +layer3456_start: + // Instructions: 83 + // Expected cycles: 94 + // Expected IPC: 0.88 + // + // Cycle bound: 94.0 + // IPC bound: 0.88 + // + // Wall time: 3.34s + // User time: 3.34s + // + // ------------------------------------- cycle (expected) --------------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------------------ + trn1 v12.4S, v26.4S, v8.4S // *............................................................................................. + trn2 v26.4S, v26.4S, v8.4S // .*............................................................................................ + trn2 v8.4S, v24.4S, v16.4S // ..*........................................................................................... + trn2 v11.2D, v12.2D, v0.2D // ...*.......................................................................................... + trn1 v12.2D, v12.2D, v0.2D // ....*......................................................................................... + trn2 v16.2D, v26.2D, v8.2D // .....*........................................................................................ + trn1 v26.2D, v26.2D, v8.2D // ......*....................................................................................... + sub v8.8H, v11.8H, v16.8H // .......*...................................................................................... + add v11.8H, v11.8H, v16.8H // ........*..................................................................................... + sub v16.8H, v12.8H, v26.8H // .........*.................................................................................... 
+ add v12.8H, v12.8H, v26.8H // ..........*................................................................................... + sqrdmulh v26.8H, v8.8H, v28.8H // ...........*.................................................................................. + sqrdmulh v15.8H, v16.8H, v15.8H // ............*................................................................................. + mul v16.8H, v16.8H, v3.8H // .............*................................................................................ + mul v8.8H, v8.8H, v4.8H // ..............*............................................................................... + sub v0.8H, v12.8H, v11.8H // ...............*.............................................................................. + add v12.8H, v12.8H, v11.8H // ................*............................................................................. + mls v16.8H, v15.8H, v7.H[0] // .................*............................................................................ + mls v8.8H, v26.8H, v7.H[0] // ..................*........................................................................... + sqrdmulh v26.8H, v0.8H, v6.8H // ...................*.......................................................................... + mul v11.8H, v0.8H, v9.8H // ....................*......................................................................... + ldr q15, [x1], #16 // .....................*........................................................................ + sub v0.8H, v16.8H, v8.8H // .......................*...................................................................... + mls v11.8H, v26.8H, v7.H[0] // ........................*..................................................................... + add v26.8H, v16.8H, v8.8H // .........................*.................................................................... 
+ sqrdmulh v8.8H, v0.8H, v6.8H // ..........................*................................................................... + mul v16.8H, v0.8H, v9.8H // ...........................*.................................................................. + trn1 v0.4S, v12.4S, v26.4S // ............................*................................................................. + trn2 v12.4S, v12.4S, v26.4S // .............................*................................................................ + ldr q26, [x3, #64] // ..............................e............................................................... + mls v16.8H, v8.8H, v7.H[0] // ................................*............................................................. + ldr q8, [x3, #80] // .................................e............................................................ + ldr q24, [x3, #96] // ...................................e.......................................................... + trn1 v9.4S, v11.4S, v16.4S // .....................................*........................................................ + trn2 v11.4S, v11.4S, v16.4S // ......................................*....................................................... + ldr q16, [x3, #112] // .......................................e...................................................... + trn2 v6.2D, v0.2D, v9.2D // .........................................*.................................................... + trn2 v3.2D, v12.2D, v11.2D // ..........................................*................................................... + trn1 v0.2D, v0.2D, v9.2D // ...........................................*.................................................. + trn1 v12.2D, v12.2D, v11.2D // ............................................*................................................. + sub v11.8H, v6.8H, v3.8H // .............................................*................................................ 
+ sub v9.8H, v0.8H, v12.8H // ..............................................*............................................... + add v12.8H, v0.8H, v12.8H // ...............................................*.............................................. + sqrdmulh v0.8H, v11.8H, v15.H[5] // ................................................*............................................. + sqrdmulh v4.8H, v9.8H, v15.H[3] // .................................................*............................................ + mul v9.8H, v9.8H, v15.H[2] // ..................................................*........................................... + mul v11.8H, v11.8H, v15.H[4] // ...................................................*.......................................... + add v6.8H, v6.8H, v3.8H // ....................................................*......................................... + sqdmulh v3.8H, v12.8H, v7.H[1] // .....................................................*........................................ + mls v9.8H, v4.8H, v7.H[0] // ......................................................*....................................... + mls v11.8H, v0.8H, v7.H[0] // .......................................................*...................................... + sqdmulh v0.8H, v6.8H, v7.H[1] // ........................................................*..................................... + srshr v3.8H, v3.8H, #11 // .........................................................*.................................... + sqdmulh v4.8H, v9.8H, v7.H[1] // ..........................................................*................................... + sqdmulh v28.8H, v11.8H, v7.H[1] // ...........................................................*.................................. + mls v12.8H, v3.8H, v7.H[0] // ............................................................*................................. 
+ srshr v0.8H, v0.8H, #11 // .............................................................*................................ + srshr v3.8H, v4.8H, #11 // ..............................................................*............................... + srshr v4.8H, v28.8H, #11 // ...............................................................*.............................. + mls v6.8H, v0.8H, v7.H[0] // ................................................................*............................. + mls v9.8H, v3.8H, v7.H[0] // .................................................................*............................ + mls v11.8H, v4.8H, v7.H[0] // ..................................................................*........................... + trn1 v0.4S, v24.4S, v16.4S // ...................................................................e.......................... + sub v3.8H, v12.8H, v6.8H // ....................................................................*......................... + add v12.8H, v12.8H, v6.8H // .....................................................................*........................ + sub v6.8H, v9.8H, v11.8H // ......................................................................*....................... + sqrdmulh v4.8H, v3.8H, v15.H[1] // .......................................................................*...................... + mul v3.8H, v3.8H, v15.H[0] // ........................................................................*..................... + sqrdmulh v28.8H, v6.8H, v15.H[1] // .........................................................................*.................... + mul v15.8H, v6.8H, v15.H[0] // ..........................................................................*................... + add v11.8H, v9.8H, v11.8H // ...........................................................................*.................. 
+ mls v3.8H, v4.8H, v7.H[0] // ............................................................................*................. + str q12, [x3], #(64) // .............................................................................*................ + mls v15.8H, v28.8H, v7.H[0] // ..............................................................................*............... + str q11, [x3, #-48] // ...............................................................................*.............. + ldr q9, [x2], #(6*16) // ................................................................................e............. + str q3, [x3, #-32] // ..................................................................................*........... + ldr q6, [x2, #-80] // ...................................................................................e.......... + str q15, [x3, #-16] // .....................................................................................*........ + ldr q3, [x2, #-64] // ......................................................................................e....... + ldr q15, [x2, #-48] // ........................................................................................e..... + ldr q4, [x2, #-32] // ..........................................................................................e... + ldr q28, [x2, #-16] // ............................................................................................e. + + // ----------------------------------------------------------------- cycle (expected) ------------------------------------------------------------------> + // 0 25 50 75 100 125 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------ + // ldr q8, [x3, #(16*0)] // e...............................................................'.............................~....................................................... 
+ // ldr q9, [x3, #(16*1)] // ...e............................................................'................................~.................................................... + // ldr q10, [x3, #(16*2)] // .....e..........................................................'..................................~.................................................. + // ldr q11, [x3, #(16*3)] // .........e......................................................'......................................~.............................................. + // trn1 v25.4s, v8.4s, v9.4s // ................................................................*..................................................................................... + // trn2 v26.4s, v8.4s, v9.4s // ................................................................'*.................................................................................... + // trn1 v27.4s, v10.4s, v11.4s // .....................................e..........................'..................................................................~.................. + // trn2 v28.4s, v10.4s, v11.4s // ................................................................'.*................................................................................... + // trn2 v10.2d, v25.2d, v27.2d // ................................................................'..*.................................................................................. + // trn2 v11.2d, v26.2d, v28.2d // ................................................................'....*................................................................................ + // trn1 v8.2d, v25.2d, v27.2d // ................................................................'...*................................................................................. 
+ // trn1 v9.2d, v26.2d, v28.2d // ................................................................'.....*............................................................................... + // ldr q0, [x2], #(6*16) // ..................................................e.............'...............................................................................~..... + // ldr q4, [x2, #(-6*16 + 1*16)] // .....................................................e..........'..................................................................................~.. + // ldr q1, [x2, #(-6*16 + 2*16)] // ........................................................e.......'..................................................................................... + // ldr q5, [x2, #(-6*16 + 3*16)] // ..........................................................e.....'..................................................................................... + // ldr q2, [x2, #(-6*16 + 4*16)] // ............................................................e...'..................................................................................... + // ldr q6, [x2, #(-6*16 + 5*16)] // ..............................................................e.'..................................................................................... + // sub v24.8h, v8.8h, v9.8h // ................................................................'........*............................................................................ + // add v8.8h, v8.8h, v9.8h // ................................................................'.........*........................................................................... + // sqrdmulh v27.8h, v24.8h, v5.8h // ................................................................'...........*......................................................................... 
+ // mul v9.8h, v24.8h, v1.8h // ................................................................'............*........................................................................ + // mls v9.8h, v27.8h, v7.h[0] // ................................................................'................*.................................................................... + // sub v24.8h, v10.8h, v11.8h // ................................................................'......*.............................................................................. + // add v10.8h, v10.8h, v11.8h // ................................................................'.......*............................................................................. + // sqrdmulh v27.8h, v24.8h, v6.8h // ................................................................'..........*.......................................................................... + // mul v11.8h, v24.8h, v2.8h // ................................................................'.............*....................................................................... + // mls v11.8h, v27.8h, v7.h[0] // ................................................................'.................*................................................................... + // sub v24.8h, v8.8h, v10.8h // ................................................................'..............*...................................................................... + // add v8.8h, v8.8h, v10.8h // ................................................................'...............*..................................................................... + // sqrdmulh v27.8h, v24.8h, v4.8h // ................................................................'..................*.................................................................. 
+ // mul v10.8h, v24.8h, v0.8h // ................................................................'...................*................................................................. + // mls v10.8h, v27.8h, v7.h[0] // ................................................................'.......................*............................................................. + // sub v24.8h, v9.8h, v11.8h // ................................................................'......................*.............................................................. + // add v9.8h, v9.8h, v11.8h // ................................................................'........................*............................................................ + // sqrdmulh v27.8h, v24.8h, v4.8h // ................................................................'.........................*........................................................... + // mul v11.8h, v24.8h, v0.8h // ................................................................'..........................*.......................................................... + // mls v11.8h, v27.8h, v7.h[0] // ..~.............................................................'...............................*..................................................... + // trn1 v25.4s, v8.4s, v9.4s // ................................................................'...........................*......................................................... + // trn2 v26.4s, v8.4s, v9.4s // ................................................................'............................*........................................................ + // trn1 v27.4s, v10.4s, v11.4s // .......~........................................................'....................................*................................................ 
+ // trn2 v28.4s, v10.4s, v11.4s // ........~.......................................................'.....................................*............................................... + // trn2 v10.2d, v25.2d, v27.2d // ...........~....................................................'........................................*............................................ + // trn2 v11.2d, v26.2d, v28.2d // ............~...................................................'.........................................*........................................... + // trn1 v8.2d, v25.2d, v27.2d // .............~..................................................'..........................................*.......................................... + // trn1 v9.2d, v26.2d, v28.2d // ..............~.................................................'...........................................*......................................... + // ldr q0, [x1], #16 // ................................................................'....................*................................................................ + // sub v24.8h, v8.8h, v9.8h // ................~...............................................'.............................................*....................................... + // add v8.8h, v8.8h, v9.8h // .................~..............................................'..............................................*...................................... + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...................~............................................'................................................*.................................... + // mul v9.8h, v24.8h, v0.h[2] // ....................~...........................................'.................................................*................................... 
+ // mls v9.8h, v27.8h, v7.h[0] // ........................~.......................................'.....................................................*............................... + // sub v24.8h, v10.8h, v11.8h // ...............~................................................'............................................*........................................ + // add v10.8h, v10.8h, v11.8h // ......................~.........................................'...................................................*................................. + // sqrdmulh v27.8h, v24.8h, v0.h[5] // ..................~.............................................'...............................................*..................................... + // mul v11.8h, v24.8h, v0.h[4] // .....................~..........................................'..................................................*.................................. + // mls v11.8h, v27.8h, v7.h[0] // .........................~......................................'......................................................*.............................. + // sqdmulh v25.8h, v8.8h, v7.h[1] // .......................~........................................'....................................................*................................ + // srshr v25.8h, v25.8h, #11 // ...........................~....................................'........................................................*............................ + // mls v8.8h, v25.8h, v7.h[0] // ..............................~.................................'...........................................................*......................... + // sqdmulh v25.8h, v10.8h, v7.h[1] // ..........................~.....................................'.......................................................*............................. 
+ // srshr v25.8h, v25.8h, #11 // ...............................~................................'............................................................*........................ + // mls v10.8h, v25.8h, v7.h[0] // ..................................~.............................'...............................................................*..................... + // sqdmulh v25.8h, v9.8h, v7.h[1] // ............................~...................................'.........................................................*........................... + // srshr v25.8h, v25.8h, #11 // ................................~...............................'.............................................................*....................... + // mls v9.8h, v25.8h, v7.h[0] // ...................................~............................'................................................................*.................... + // sqdmulh v25.8h, v11.8h, v7.h[1] // .............................~..................................'..........................................................*.......................... + // srshr v25.8h, v25.8h, #11 // .................................~..............................'..............................................................*...................... + // mls v11.8h, v25.8h, v7.h[0] // ....................................~...........................'.................................................................*................... + // sub v24.8h, v8.8h, v10.8h // ......................................~.........................'...................................................................*................. + // add v8.8h, v8.8h, v10.8h // .......................................~........................'....................................................................*................ 
+ // sqrdmulh v27.8h, v24.8h, v0.h[1] // .........................................~......................'......................................................................*.............. + // mul v10.8h, v24.8h, v0.h[0] // ..........................................~.....................'.......................................................................*............. + // mls v10.8h, v27.8h, v7.h[0] // ..............................................~.................'...........................................................................*......... + // sub v24.8h, v9.8h, v11.8h // ........................................~.......................'.....................................................................*............... + // add v9.8h, v9.8h, v11.8h // .............................................~..................'..........................................................................*.......... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................................~....................'........................................................................*............ + // mul v11.8h, v24.8h, v0.h[0] // ............................................~...................'.........................................................................*........... + // mls v11.8h, v27.8h, v7.h[0] // ................................................~...............'.............................................................................*....... + // str q8, [x3], #(64) // ...............................................~................'............................................................................*........ + // str q9, [x3, #(-64 + 16*1)] // .................................................~..............'..............................................................................*...... 
+ // str q10, [x3, #(-64 + 16*2)] // ....................................................~...........'.................................................................................*... + // str q11, [x3, #(-64 + 16*3)] // .......................................................~........'....................................................................................* + + sub count, count, #1 + cbnz count, layer3456_start + // Instructions: 72 + // Expected cycles: 79 + // Expected IPC: 0.91 + // + // Cycle bound: 79.0 + // IPC bound: 0.91 + // + // Wall time: 9.28s + // User time: 9.28s + // + // ------------------------------ cycle (expected) ------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|--- + trn1 v11.4S, v26.4S, v8.4S // *.............................................................................. + trn2 v24.4S, v24.4S, v16.4S // .*............................................................................. + trn2 v26.4S, v26.4S, v8.4S // ..*............................................................................ + trn1 v18.2D, v11.2D, v0.2D // ...*........................................................................... + trn2 v11.2D, v11.2D, v0.2D // ....*.......................................................................... + trn2 v12.2D, v26.2D, v24.2D // .....*......................................................................... + trn1 v8.2D, v26.2D, v24.2D // ......*........................................................................ + sub v26.8H, v11.8H, v12.8H // .......*....................................................................... + sub v13.8H, v18.8H, v8.8H // ........*...................................................................... + add v24.8H, v18.8H, v8.8H // .........*..................................................................... 
+ mul v16.8H, v26.8H, v4.8H // ..........*.................................................................... + sqrdmulh v17.8H, v13.8H, v15.8H // ...........*................................................................... + mul v3.8H, v13.8H, v3.8H // ............*.................................................................. + sqrdmulh v26.8H, v26.8H, v28.8H // .............*................................................................. + add v10.8H, v11.8H, v12.8H // ..............*................................................................ + mls v3.8H, v17.8H, v7.H[0] // ................*.............................................................. + mls v16.8H, v26.8H, v7.H[0] // .................*............................................................. + sub v26.8H, v24.8H, v10.8H // ..................*............................................................ + ldr q4, [x1], #16 // ...................*........................................................... + sub v12.8H, v3.8H, v16.8H // .....................*......................................................... + sqrdmulh v15.8H, v26.8H, v6.8H // ......................*........................................................ + mul v11.8H, v26.8H, v9.8H // .......................*....................................................... + mul v8.8H, v12.8H, v9.8H // ........................*...................................................... + sqrdmulh v12.8H, v12.8H, v6.8H // .........................*..................................................... + add v0.8H, v24.8H, v10.8H // ..........................*.................................................... + mls v11.8H, v15.8H, v7.H[0] // ...........................*................................................... + add v6.8H, v3.8H, v16.8H // ............................*.................................................. 
+ mls v8.8H, v12.8H, v7.H[0] // .............................*................................................. + trn2 v26.4S, v0.4S, v6.4S // ...............................*............................................... + trn2 v12.4S, v11.4S, v8.4S // .................................*............................................. + trn1 v3.4S, v11.4S, v8.4S // ..................................*............................................ + trn1 v17.4S, v0.4S, v6.4S // ...................................*........................................... + trn1 v8.2D, v26.2D, v12.2D // ....................................*.......................................... + trn2 v13.2D, v26.2D, v12.2D // .....................................*......................................... + trn1 v11.2D, v17.2D, v3.2D // ......................................*........................................ + trn2 v15.2D, v17.2D, v3.2D // .......................................*....................................... + sub v12.8H, v11.8H, v8.8H // ........................................*...................................... + add v16.8H, v15.8H, v13.8H // .........................................*..................................... + sub v26.8H, v15.8H, v13.8H // ..........................................*.................................... + mul v0.8H, v12.8H, v4.H[2] // ...........................................*................................... + sqrdmulh v9.8H, v12.8H, v4.H[3] // ............................................*.................................. + mul v13.8H, v26.8H, v4.H[4] // .............................................*................................. + sqrdmulh v26.8H, v26.8H, v4.H[5] // ..............................................*................................ + add v24.8H, v11.8H, v8.8H // ...............................................*............................... 
+ mls v0.8H, v9.8H, v7.H[0] // ................................................*.............................. + sqdmulh v12.8H, v16.8H, v7.H[1] // .................................................*............................. + mls v13.8H, v26.8H, v7.H[0] // ..................................................*............................ + sqdmulh v11.8H, v24.8H, v7.H[1] // ...................................................*........................... + sqdmulh v8.8H, v0.8H, v7.H[1] // ....................................................*.......................... + srshr v12.8H, v12.8H, #11 // .....................................................*......................... + sqdmulh v26.8H, v13.8H, v7.H[1] // ......................................................*........................ + srshr v11.8H, v11.8H, #11 // .......................................................*....................... + mls v16.8H, v12.8H, v7.H[0] // ........................................................*...................... + srshr v8.8H, v8.8H, #11 // .........................................................*..................... + srshr v26.8H, v26.8H, #11 // ..........................................................*.................... + mls v24.8H, v11.8H, v7.H[0] // ...........................................................*................... + mls v0.8H, v8.8H, v7.H[0] // ............................................................*.................. + mls v13.8H, v26.8H, v7.H[0] // .............................................................*................. + sub v26.8H, v24.8H, v16.8H // ...............................................................*............... + add v15.8H, v24.8H, v16.8H // ................................................................*.............. + sub v12.8H, v0.8H, v13.8H // .................................................................*............. 
+ mul v11.8H, v26.8H, v4.H[0] // ..................................................................*............ + sqrdmulh v16.8H, v26.8H, v4.H[1] // ...................................................................*........... + mul v26.8H, v12.8H, v4.H[0] // ....................................................................*.......... + sqrdmulh v8.8H, v12.8H, v4.H[1] // .....................................................................*......... + add v12.8H, v0.8H, v13.8H // ......................................................................*........ + mls v11.8H, v16.8H, v7.H[0] // .......................................................................*....... + str q15, [x3], #(64) // ........................................................................*...... + mls v26.8H, v8.8H, v7.H[0] // .........................................................................*..... + str q12, [x3, #-48] // ..........................................................................*.... + str q11, [x3, #-32] // ............................................................................*.. + str q26, [x3, #-16] // ..............................................................................* + + // ------------------------------ cycle (expected) ------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|--- + // trn1 v12.4S, v26.4S, v8.4S // *.............................................................................. + // trn2 v26.4S, v26.4S, v8.4S // ..*............................................................................ + // trn2 v8.4S, v24.4S, v16.4S // .*............................................................................. + // trn2 v11.2D, v12.2D, v0.2D // ....*.......................................................................... + // trn1 v12.2D, v12.2D, v0.2D // ...*........................................................................... 
+ // trn2 v16.2D, v26.2D, v8.2D // .....*......................................................................... + // trn1 v26.2D, v26.2D, v8.2D // ......*........................................................................ + // sub v8.8H, v11.8H, v16.8H // .......*....................................................................... + // add v11.8H, v11.8H, v16.8H // ..............*................................................................ + // sub v16.8H, v12.8H, v26.8H // ........*...................................................................... + // add v12.8H, v12.8H, v26.8H // .........*..................................................................... + // sqrdmulh v26.8H, v8.8H, v28.8H // .............*................................................................. + // sqrdmulh v15.8H, v16.8H, v15.8H // ...........*................................................................... + // mul v16.8H, v16.8H, v3.8H // ............*.................................................................. + // mul v8.8H, v8.8H, v4.8H // ..........*.................................................................... + // sub v0.8H, v12.8H, v11.8H // ..................*............................................................ + // add v12.8H, v12.8H, v11.8H // ..........................*.................................................... + // mls v16.8H, v15.8H, v7.H[0] // ................*.............................................................. + // mls v8.8H, v26.8H, v7.H[0] // .................*............................................................. + // sqrdmulh v26.8H, v0.8H, v6.8H // ......................*........................................................ + // mul v11.8H, v0.8H, v9.8H // .......................*....................................................... + // ldr q15, [x1], #16 // ...................*........................................................... 
+ // sub v0.8H, v16.8H, v8.8H // .....................*......................................................... + // mls v11.8H, v26.8H, v7.H[0] // ...........................*................................................... + // add v26.8H, v16.8H, v8.8H // ............................*.................................................. + // sqrdmulh v8.8H, v0.8H, v6.8H // .........................*..................................................... + // mul v16.8H, v0.8H, v9.8H // ........................*...................................................... + // trn1 v0.4S, v12.4S, v26.4S // ...................................*........................................... + // trn2 v12.4S, v12.4S, v26.4S // ...............................*............................................... + // mls v16.8H, v8.8H, v7.H[0] // .............................*................................................. + // trn1 v9.4S, v11.4S, v16.4S // ..................................*............................................ + // trn2 v11.4S, v11.4S, v16.4S // .................................*............................................. + // trn2 v6.2D, v0.2D, v9.2D // .......................................*....................................... + // trn2 v3.2D, v12.2D, v11.2D // .....................................*......................................... + // trn1 v0.2D, v0.2D, v9.2D // ......................................*........................................ + // trn1 v12.2D, v12.2D, v11.2D // ....................................*.......................................... + // sub v11.8H, v6.8H, v3.8H // ..........................................*.................................... + // sub v9.8H, v0.8H, v12.8H // ........................................*...................................... + // add v12.8H, v0.8H, v12.8H // ...............................................*............................... 
+ // sqrdmulh v0.8H, v11.8H, v15.H[5] // ..............................................*................................ + // sqrdmulh v4.8H, v9.8H, v15.H[3] // ............................................*.................................. + // mul v9.8H, v9.8H, v15.H[2] // ...........................................*................................... + // mul v11.8H, v11.8H, v15.H[4] // .............................................*................................. + // add v6.8H, v6.8H, v3.8H // .........................................*..................................... + // sqdmulh v3.8H, v12.8H, v7.H[1] // ...................................................*........................... + // mls v9.8H, v4.8H, v7.H[0] // ................................................*.............................. + // mls v11.8H, v0.8H, v7.H[0] // ..................................................*............................ + // sqdmulh v0.8H, v6.8H, v7.H[1] // .................................................*............................. + // srshr v3.8H, v3.8H, #11 // .......................................................*....................... + // sqdmulh v4.8H, v9.8H, v7.H[1] // ....................................................*.......................... + // sqdmulh v28.8H, v11.8H, v7.H[1] // ......................................................*........................ + // mls v12.8H, v3.8H, v7.H[0] // ...........................................................*................... + // srshr v0.8H, v0.8H, #11 // .....................................................*......................... + // srshr v3.8H, v4.8H, #11 // .........................................................*..................... + // srshr v4.8H, v28.8H, #11 // ..........................................................*.................... + // mls v6.8H, v0.8H, v7.H[0] // ........................................................*...................... 
+ // mls v9.8H, v3.8H, v7.H[0] // ............................................................*.................. + // mls v11.8H, v4.8H, v7.H[0] // .............................................................*................. + // sub v3.8H, v12.8H, v6.8H // ...............................................................*............... + // add v12.8H, v12.8H, v6.8H // ................................................................*.............. + // sub v6.8H, v9.8H, v11.8H // .................................................................*............. + // sqrdmulh v4.8H, v3.8H, v15.H[1] // ...................................................................*........... + // mul v3.8H, v3.8H, v15.H[0] // ..................................................................*............ + // sqrdmulh v28.8H, v6.8H, v15.H[1] // .....................................................................*......... + // mul v15.8H, v6.8H, v15.H[0] // ....................................................................*.......... + // add v11.8H, v9.8H, v11.8H // ......................................................................*........ + // mls v3.8H, v4.8H, v7.H[0] // .......................................................................*....... + // str q12, [x3], #(64) // ........................................................................*...... + // mls v15.8H, v28.8H, v7.H[0] // .........................................................................*..... + // str q11, [x3, #-48] // ..........................................................................*.... + // str q3, [x3, #-32] // ............................................................................*.. 
+ // str q15, [x3, #-16] // ..............................................................................* + + + // --------------------------------------------------------------------- + + mov count, #4 + load_roots_012 + + .p2align 2 + + // Instructions: 12 + // Expected cycles: 19 + // Expected IPC: 0.63 + // + // Cycle bound: 19.0 + // IPC bound: 0.63 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q24, [x0, #128] // *............................. + ldr q16, [x0, #192] // ..*........................... + ldr q9, [x0, #256] // ....*......................... + ldr q6, [x0, #320] // ......*....................... + ldr q3, [x0, #384] // ........*..................... + ldr q4, [x0, #448] // ..........*................... + add v28.8H, v9.8H, v6.8H // ............*................. + add v19.8H, v24.8H, v16.8H // .............*................ + add v13.8H, v3.8H, v4.8H // ..............*............... + ldr q11, [x0, #0] // ...............*.............. + add v23.8H, v28.8H, v13.8H // .................*............ + ldr q15, [x0, #64] // ..................*........... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q11, [x0, #0] // ...............*............... + // ldr q15, [x0, #64] // ..................*............ + // ldr q24, [x0, #128] // *.............................. + // ldr q16, [x0, #192] // ..*............................ + // ldr q9, [x0, #256] // ....*.......................... + // ldr q6, [x0, #320] // ......*........................ + // ldr q3, [x0, #384] // ........*...................... + // ldr q4, [x0, #448] // ..........*.................... + // add v28.8H, v9.8H, v6.8H // ............*.................. + // add v13.8H, v3.8H, v4.8H // ..............*................ + // add v19.8H, v24.8H, v16.8H // .............*................. 
+ // add v23.8H, v28.8H, v13.8H // .................*............. + + sub count, count, #1 +layer012_start: + // Instructions: 76 + // Expected cycles: 84 + // Expected IPC: 0.90 + // + // Cycle bound: 84.0 + // IPC bound: 0.90 + // + // Wall time: 2.81s + // User time: 2.81s + // + // -------------------------------- cycle (expected) ---------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|-------- + sub v12.8H, v11.8H, v15.8H // *................................................................................... + add v26.8H, v11.8H, v15.8H // .*.................................................................................. + sub v8.8H, v24.8H, v16.8H // ..*................................................................................. + sqrdmulh v11.8H, v12.8H, v0.H[7] // ...*................................................................................ + mul v12.8H, v12.8H, v0.H[6] // ....*............................................................................... + sub v16.8H, v26.8H, v19.8H // .....*.............................................................................. + add v26.8H, v26.8H, v19.8H // ......*............................................................................. + sqrdmulh v15.8H, v8.8H, v1.H[1] // .......*............................................................................ + mul v8.8H, v8.8H, v1.H[0] // ........*........................................................................... + mls v12.8H, v11.8H, v7.H[0] // .........*.......................................................................... + sub v11.8H, v9.8H, v6.8H // ..........*......................................................................... + sqrdmulh v24.8H, v16.8H, v0.H[3] // ...........*........................................................................ 
+ mul v16.8H, v16.8H, v0.H[2] // ............*....................................................................... + sub v9.8H, v26.8H, v23.8H // .............*...................................................................... + add v26.8H, v26.8H, v23.8H // ..............*..................................................................... + mls v8.8H, v15.8H, v7.H[0] // ...............*.................................................................... + sqrdmulh v15.8H, v11.8H, v1.H[3] // ................*................................................................... + mul v11.8H, v11.8H, v1.H[2] // .................*.................................................................. + sub v6.8H, v3.8H, v4.8H // ..................*................................................................. + sub v3.8H, v12.8H, v8.8H // ...................*................................................................ + add v12.8H, v12.8H, v8.8H // ....................*............................................................... + mls v11.8H, v15.8H, v7.H[0] // .....................*.............................................................. + sqrdmulh v8.8H, v6.8H, v1.H[5] // ......................*............................................................. + mls v16.8H, v24.8H, v7.H[0] // .......................*............................................................ + mul v15.8H, v6.8H, v1.H[4] // ........................*........................................................... + sqrdmulh v24.8H, v3.8H, v0.H[3] // .........................*.......................................................... + mul v6.8H, v3.8H, v0.H[2] // ..........................*......................................................... + sqrdmulh v3.8H, v9.8H, v0.H[1] // ...........................*........................................................ 
+ mul v9.8H, v9.8H, v0.H[0] // ............................*....................................................... + str q26, [x0], #(16) // .............................*...................................................... + mls v15.8H, v8.8H, v7.H[0] // ..............................*..................................................... + mls v6.8H, v24.8H, v7.H[0] // ...............................*.................................................... + sub v26.8H, v28.8H, v13.8H // ................................*................................................... + mls v9.8H, v3.8H, v7.H[0] // .................................*.................................................. + sub v8.8H, v11.8H, v15.8H // ..................................*................................................. + sqrdmulh v24.8H, v26.8H, v0.H[5] // ...................................*................................................ + mul v26.8H, v26.8H, v0.H[4] // ....................................*............................................... + add v11.8H, v11.8H, v15.8H // .....................................*.............................................. + sqrdmulh v15.8H, v8.8H, v0.H[5] // ......................................*............................................. + mul v8.8H, v8.8H, v0.H[4] // .......................................*............................................ + mls v26.8H, v24.8H, v7.H[0] // ........................................*........................................... + sub v24.8H, v12.8H, v11.8H // .........................................*.......................................... + add v12.8H, v12.8H, v11.8H // ..........................................*......................................... + mls v8.8H, v15.8H, v7.H[0] // ...........................................*........................................ + sqrdmulh v11.8H, v24.8H, v0.H[1] // ............................................*....................................... 
+ mul v15.8H, v24.8H, v0.H[0] // .............................................*...................................... + sub v24.8H, v16.8H, v26.8H // ..............................................*..................................... + add v26.8H, v16.8H, v26.8H // ...............................................*.................................... + sub v16.8H, v6.8H, v8.8H // ................................................*................................... + mls v15.8H, v11.8H, v7.H[0] // .................................................*.................................. + sqrdmulh v11.8H, v24.8H, v0.H[1] // ..................................................*................................. + mul v24.8H, v24.8H, v0.H[0] // ...................................................*................................ + add v8.8H, v6.8H, v8.8H // ....................................................*............................... + sqrdmulh v6.8H, v16.8H, v0.H[1] // .....................................................*.............................. + mul v16.8H, v16.8H, v0.H[0] // ......................................................*............................. + mls v24.8H, v11.8H, v7.H[0] // .......................................................*............................ + str q9, [x0, #240] // ........................................................*........................... + ldr q11, [x0, #0] // .........................................................e.......................... + mls v16.8H, v6.8H, v7.H[0] // ...........................................................*........................ + str q15, [x0, #304] // ............................................................*....................... + ldr q15, [x0, #64] // .............................................................e...................... + str q24, [x0, #368] // ...............................................................*.................... 
+ ldr q24, [x0, #128] // ................................................................e................... + str q16, [x0, #432] // ..................................................................*................. + ldr q16, [x0, #192] // ...................................................................e................ + str q12, [x0, #48] // .....................................................................*.............. + ldr q9, [x0, #256] // ......................................................................e............. + ldr q6, [x0, #320] // ........................................................................e........... + ldr q3, [x0, #384] // ..........................................................................e......... + ldr q4, [x0, #448] // ............................................................................e....... + str q26, [x0, #112] // ..............................................................................*..... + add v28.8H, v9.8H, v6.8H // ...............................................................................e.... + add v13.8H, v3.8H, v4.8H // ................................................................................e... + str q8, [x0, #176] // .................................................................................*.. + add v19.8H, v24.8H, v16.8H // ..................................................................................e. + add v23.8H, v28.8H, v13.8H // ...................................................................................e + + // --------------------------------------------- cycle (expected) ---------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|-------- + // ldr q8, [x0, #0] // e..........................'........................................................~........................ 
+ // ldr q9, [x0, #(1*(512/8))] // ....e......................'............................................................~.................... + // ldr q10, [x0, #(2*(512/8))] // .......e...................'...............................................................~................. + // ldr q11, [x0, #(3*(512/8))] // ..........e................'..................................................................~.............. + // ldr q12, [x0, #(4*(512/8))] // .............e.............'.....................................................................~........... + // ldr q13, [x0, #(5*(512/8))] // ...............e...........'.......................................................................~......... + // ldr q14, [x0, #(6*(512/8))] // .................e.........'.........................................................................~....... + // ldr q15, [x0, #(7*(512/8))] // ...................e.......'...........................................................................~..... + // sub v24.8h, v8.8h, v9.8h // ...........................*................................................................................. + // add v8.8h, v8.8h, v9.8h // ...........................'*................................................................................ + // sqrdmulh v27.8h, v24.8h, v0.h[7] // ...........................'..*.............................................................................. + // mul v9.8h, v24.8h, v0.h[6] // ...........................'...*............................................................................. + // mls v9.8h, v27.8h, v7.h[0] // ...........................'........*........................................................................ + // sub v24.8h, v10.8h, v11.8h // ...........................'.*............................................................................... 
+ // add v10.8h, v10.8h, v11.8h // .........................e.'................................................................................. + // sqrdmulh v27.8h, v24.8h, v1.h[1] // ...........................'......*.......................................................................... + // mul v11.8h, v24.8h, v1.h[0] // ...........................'.......*......................................................................... + // mls v11.8h, v27.8h, v7.h[0] // ...........................'..............*.................................................................. + // sub v24.8h, v12.8h, v13.8h // ...........................'.........*....................................................................... + // add v12.8h, v12.8h, v13.8h // ......................e....'..............................................................................~.. + // sqrdmulh v27.8h, v24.8h, v1.h[3] // ...........................'...............*................................................................. + // mul v13.8h, v24.8h, v1.h[2] // ...........................'................*................................................................ + // mls v13.8h, v27.8h, v7.h[0] // ...........................'....................*............................................................ + // sub v24.8h, v14.8h, v15.8h // ...........................'.................*............................................................... + // add v14.8h, v14.8h, v15.8h // .......................e...'...............................................................................~. + // sqrdmulh v27.8h, v24.8h, v1.h[5] // ...........................'.....................*........................................................... + // mul v15.8h, v24.8h, v1.h[4] // ...........................'.......................*......................................................... 
+ // mls v15.8h, v27.8h, v7.h[0] // ...........................'.............................*................................................... + // sub v24.8h, v8.8h, v10.8h // ...........................'....*............................................................................ + // add v8.8h, v8.8h, v10.8h // ...........................'.....*........................................................................... + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...........................'..........*...................................................................... + // mul v10.8h, v24.8h, v0.h[2] // ...........................'...........*..................................................................... + // mls v10.8h, v27.8h, v7.h[0] // ...........................'......................*.......................................................... + // sub v24.8h, v9.8h, v11.8h // ...........................'..................*.............................................................. + // add v9.8h, v9.8h, v11.8h // ...........................'...................*............................................................. + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...........................'........................*........................................................ + // mul v11.8h, v24.8h, v0.h[2] // ...........................'.........................*....................................................... + // mls v11.8h, v27.8h, v7.h[0] // ...........................'..............................*.................................................. + // sub v24.8h, v12.8h, v14.8h // ...........................'...............................*................................................. + // add v12.8h, v12.8h, v14.8h // ..........................e'................................................................................. 
+ // sqrdmulh v27.8h, v24.8h, v0.h[5] // ...........................'..................................*.............................................. + // mul v14.8h, v24.8h, v0.h[4] // ...........................'...................................*............................................. + // mls v14.8h, v27.8h, v7.h[0] // ...........................'.......................................*......................................... + // sub v24.8h, v13.8h, v15.8h // ...........................'.................................*............................................... + // add v13.8h, v13.8h, v15.8h // ...........................'....................................*............................................ + // sqrdmulh v27.8h, v24.8h, v0.h[5] // ...........................'.....................................*........................................... + // mul v15.8h, v24.8h, v0.h[4] // ...........................'......................................*.......................................... + // mls v15.8h, v27.8h, v7.h[0] // ...........................'..........................................*...................................... + // sub v24.8h, v8.8h, v12.8h // ...........................'............*.................................................................... + // add v8.8h, v8.8h, v12.8h // ...........................'.............*................................................................... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'..........................*...................................................... + // mul v12.8h, v24.8h, v0.h[0] // ...........................'...........................*..................................................... + // mls v12.8h, v27.8h, v7.h[0] // ...........................'................................*................................................ 
+ // sub v24.8h, v9.8h, v13.8h // ...........................'........................................*........................................ + // add v9.8h, v9.8h, v13.8h // ...........................'.........................................*....................................... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'...........................................*..................................... + // mul v13.8h, v24.8h, v0.h[0] // ...........................'............................................*.................................... + // mls v13.8h, v27.8h, v7.h[0] // ...........................'................................................*................................ + // sub v24.8h, v10.8h, v14.8h // ...........................'.............................................*................................... + // add v10.8h, v10.8h, v14.8h // ...........................'..............................................*.................................. + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'.................................................*............................... + // mul v14.8h, v24.8h, v0.h[0] // ...........................'..................................................*.............................. + // mls v14.8h, v27.8h, v7.h[0] // ...........................'......................................................*.......................... + // sub v24.8h, v11.8h, v15.8h // ...........................'...............................................*................................. + // add v11.8h, v11.8h, v15.8h // ...........................'...................................................*............................. + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'....................................................*............................ 
+ // mul v15.8h, v24.8h, v0.h[0] // ...........................'.....................................................*........................... + // mls v15.8h, v27.8h, v7.h[0] // ..~........................'..........................................................*...................... + // str q12, [x0, #(4*(512/8))] // ...........................'.......................................................*......................... + // str q13, [x0, #(5*(512/8))] // ...~.......................'...........................................................*..................... + // str q14, [x0, #(6*(512/8))] // ......~....................'..............................................................*.................. + // str q15, [x0, #(7*(512/8))] // .........~.................'.................................................................*............... + // str q8, [x0], #(16) // ...........................'............................*.................................................... + // str q9, [x0, #(-16 + 1*(512/8))] // ............~..............'....................................................................*............ + // str q10, [x0, #(-16 + 2*(512/8))] // .....................~.....'.............................................................................*... + // str q11, [x0, #(-16 + 3*(512/8))] // ........................~..'................................................................................* + + sub count, count, #1 + cbnz count, layer012_start + // Instructions: 64 + // Expected cycles: 66 + // Expected IPC: 0.97 + // + // Cycle bound: 66.0 + // IPC bound: 0.97 + // + // Wall time: 8.33s + // User time: 8.33s + // + // ----------------------- cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|--------------- + add v10.8H, v11.8H, v15.8H // *................................................................. 
+ sub v12.8H, v28.8H, v13.8H // .*................................................................ + sub v11.8H, v11.8H, v15.8H // ..*............................................................... + sub v22.8H, v10.8H, v19.8H // ...*.............................................................. + mul v18.8H, v12.8H, v0.H[4] // ....*............................................................. + sqrdmulh v26.8H, v12.8H, v0.H[5] // .....*............................................................ + sqrdmulh v12.8H, v22.8H, v0.H[3] // ......*........................................................... + mul v13.8H, v22.8H, v0.H[2] // .......*.......................................................... + sub v31.8H, v24.8H, v16.8H // ........*......................................................... + sqrdmulh v22.8H, v11.8H, v0.H[7] // .........*........................................................ + mls v18.8H, v26.8H, v7.H[0] // ..........*....................................................... + mls v13.8H, v12.8H, v7.H[0] // ...........*...................................................... + sqrdmulh v2.8H, v31.8H, v1.H[1] // ............*..................................................... + mul v5.8H, v31.8H, v1.H[0] // .............*.................................................... + mul v15.8H, v11.8H, v0.H[6] // ..............*................................................... + sub v12.8H, v13.8H, v18.8H // ...............*.................................................. + sub v4.8H, v3.8H, v4.8H // ................*................................................. + mls v5.8H, v2.8H, v7.H[0] // .................*................................................ + sqrdmulh v26.8H, v12.8H, v0.H[1] // ..................*............................................... + mul v12.8H, v12.8H, v0.H[0] // ...................*.............................................. 
+ mls v15.8H, v22.8H, v7.H[0] // ....................*............................................. + sqrdmulh v8.8H, v4.8H, v1.H[5] // .....................*............................................ + mul v4.8H, v4.8H, v1.H[4] // ......................*........................................... + mls v12.8H, v26.8H, v7.H[0] // .......................*.......................................... + sub v21.8H, v15.8H, v5.8H // ........................*......................................... + sub v28.8H, v9.8H, v6.8H // .........................*........................................ + mls v4.8H, v8.8H, v7.H[0] // ..........................*....................................... + mul v24.8H, v21.8H, v0.H[2] // ...........................*...................................... + sqrdmulh v8.8H, v21.8H, v0.H[3] // ............................*..................................... + sqrdmulh v6.8H, v28.8H, v1.H[3] // .............................*.................................... + add v19.8H, v10.8H, v19.8H // ..............................*................................... + mul v28.8H, v28.8H, v1.H[2] // ...............................*.................................. + mls v24.8H, v8.8H, v7.H[0] // ................................*................................. + sub v11.8H, v19.8H, v23.8H // .................................*................................ + str q12, [x0, #384] // ..................................*............................... + mls v28.8H, v6.8H, v7.H[0] // ...................................*.............................. + sqrdmulh v16.8H, v11.8H, v0.H[1] // ....................................*............................. + mul v9.8H, v11.8H, v0.H[0] // .....................................*............................ + add v6.8H, v15.8H, v5.8H // ......................................*........................... + add v26.8H, v28.8H, v4.8H // .......................................*.......................... 
+ sub v15.8H, v28.8H, v4.8H // ........................................*......................... + mls v9.8H, v16.8H, v7.H[0] // .........................................*........................ + add v3.8H, v6.8H, v26.8H // ..........................................*....................... + mul v8.8H, v15.8H, v0.H[4] // ...........................................*...................... + sqrdmulh v15.8H, v15.8H, v0.H[5] // ............................................*..................... + str q9, [x0, #256] // .............................................*.................... + sub v2.8H, v6.8H, v26.8H // ..............................................*................... + str q3, [x0, #64] // ...............................................*.................. + mls v8.8H, v15.8H, v7.H[0] // ................................................*................. + sqrdmulh v15.8H, v2.8H, v0.H[1] // .................................................*................ + mul v11.8H, v2.8H, v0.H[0] // ..................................................*............... + add v16.8H, v13.8H, v18.8H // ...................................................*.............. + sub v12.8H, v24.8H, v8.8H // ....................................................*............. + add v8.8H, v24.8H, v8.8H // .....................................................*............ + mls v11.8H, v15.8H, v7.H[0] // ......................................................*........... + sqrdmulh v26.8H, v12.8H, v0.H[1] // .......................................................*.......... + mul v12.8H, v12.8H, v0.H[0] // ........................................................*......... + str q8, [x0, #192] // .........................................................*........ + add v15.8H, v19.8H, v23.8H // ..........................................................*....... + str q11, [x0, #320] // ...........................................................*...... 
+ mls v12.8H, v26.8H, v7.H[0] // ............................................................*..... + str q15, [x0], #(16) // .............................................................*.... + str q16, [x0, #112] // ...............................................................*.. + str q12, [x0, #432] // .................................................................* + + // ----------------------- cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|--------------- + // sub v12.8H, v11.8H, v15.8H // ..*............................................................... + // add v26.8H, v11.8H, v15.8H // *................................................................. + // sub v8.8H, v24.8H, v16.8H // ........*......................................................... + // sqrdmulh v11.8H, v12.8H, v0.H[7] // .........*........................................................ + // mul v12.8H, v12.8H, v0.H[6] // ..............*................................................... + // sub v16.8H, v26.8H, v19.8H // ...*.............................................................. + // add v26.8H, v26.8H, v19.8H // ..............................*................................... + // sqrdmulh v15.8H, v8.8H, v1.H[1] // ............*..................................................... + // mul v8.8H, v8.8H, v1.H[0] // .............*.................................................... + // mls v12.8H, v11.8H, v7.H[0] // ....................*............................................. + // sub v11.8H, v9.8H, v6.8H // .........................*........................................ + // sqrdmulh v24.8H, v16.8H, v0.H[3] // ......*........................................................... + // mul v16.8H, v16.8H, v0.H[2] // .......*.......................................................... + // sub v9.8H, v26.8H, v23.8H // .................................*................................ 
+ // add v26.8H, v26.8H, v23.8H // ..........................................................*....... + // mls v8.8H, v15.8H, v7.H[0] // .................*................................................ + // sqrdmulh v15.8H, v11.8H, v1.H[3] // .............................*.................................... + // mul v11.8H, v11.8H, v1.H[2] // ...............................*.................................. + // sub v6.8H, v3.8H, v4.8H // ................*................................................. + // sub v3.8H, v12.8H, v8.8H // ........................*......................................... + // add v12.8H, v12.8H, v8.8H // ......................................*........................... + // mls v11.8H, v15.8H, v7.H[0] // ...................................*.............................. + // sqrdmulh v8.8H, v6.8H, v1.H[5] // .....................*............................................ + // mls v16.8H, v24.8H, v7.H[0] // ...........*...................................................... + // mul v15.8H, v6.8H, v1.H[4] // ......................*........................................... + // sqrdmulh v24.8H, v3.8H, v0.H[3] // ............................*..................................... + // mul v6.8H, v3.8H, v0.H[2] // ...........................*...................................... + // sqrdmulh v3.8H, v9.8H, v0.H[1] // ....................................*............................. + // mul v9.8H, v9.8H, v0.H[0] // .....................................*............................ + // str q26, [x0], #(16) // .............................................................*.... + // mls v15.8H, v8.8H, v7.H[0] // ..........................*....................................... + // mls v6.8H, v24.8H, v7.H[0] // ................................*................................. + // sub v26.8H, v28.8H, v13.8H // .*................................................................ 
+ // mls v9.8H, v3.8H, v7.H[0] // .........................................*........................ + // sub v8.8H, v11.8H, v15.8H // ........................................*......................... + // sqrdmulh v24.8H, v26.8H, v0.H[5] // .....*............................................................ + // mul v26.8H, v26.8H, v0.H[4] // ....*............................................................. + // add v11.8H, v11.8H, v15.8H // .......................................*.......................... + // sqrdmulh v15.8H, v8.8H, v0.H[5] // ............................................*..................... + // mul v8.8H, v8.8H, v0.H[4] // ...........................................*...................... + // mls v26.8H, v24.8H, v7.H[0] // ..........*....................................................... + // sub v24.8H, v12.8H, v11.8H // ..............................................*................... + // add v12.8H, v12.8H, v11.8H // ..........................................*....................... + // mls v8.8H, v15.8H, v7.H[0] // ................................................*................. + // sqrdmulh v11.8H, v24.8H, v0.H[1] // .................................................*................ + // mul v15.8H, v24.8H, v0.H[0] // ..................................................*............... + // sub v24.8H, v16.8H, v26.8H // ...............*.................................................. + // add v26.8H, v16.8H, v26.8H // ...................................................*.............. + // sub v16.8H, v6.8H, v8.8H // ....................................................*............. + // mls v15.8H, v11.8H, v7.H[0] // ......................................................*........... + // sqrdmulh v11.8H, v24.8H, v0.H[1] // ..................*............................................... + // mul v24.8H, v24.8H, v0.H[0] // ...................*.............................................. 
+ // add v8.8H, v6.8H, v8.8H // .....................................................*............ + // sqrdmulh v6.8H, v16.8H, v0.H[1] // .......................................................*.......... + // mul v16.8H, v16.8H, v0.H[0] // ........................................................*......... + // mls v24.8H, v11.8H, v7.H[0] // .......................*.......................................... + // str q9, [x0, #240] // .............................................*.................... + // mls v16.8H, v6.8H, v7.H[0] // ............................................................*..... + // str q15, [x0, #304] // ...........................................................*...... + // str q24, [x0, #368] // ..................................*............................... + // str q16, [x0, #432] // .................................................................* + // str q12, [x0, #48] // ...............................................*.................. + // str q26, [x0, #112] // ...............................................................*.. + // str q8, [x0, #176] // .........................................................*........ 
+ + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_clean.S new file mode 100644 index 0000000000..877a5f689f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/ntt_clean.S @@ -0,0 +1,283 @@ +/// +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// Copyright (c) 2024 The mlkem-native project authors +// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. 
+.macro mulmodq dst, src, const, idx0, idx1 + // Signed Barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. + sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] // t2 = rounded high half of the doubled product src * const[idx1] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] // dst = low 16 bits of src * const[idx0] + mls \dst\().8h, t2.8h, consts.h[0] // dst -= t2 * consts[0]; consts is loaded from c_consts, whose lane 0 is 3329 +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h // same three-step sequence as mulmodq, with full-vector constants instead of a single lane + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro ct_butterfly a, b, root, idx0, idx1 + mulmodq tmp, \b, \root, \idx0, \idx1 // tmp = b * root, reduced by mulmodq + sub \b\().8h, \a\().8h, tmp.8h // b = a - tmp + add \a\().8h, \a\().8h, tmp.8h // a = a + tmp +.endm + +.macro ct_butterfly_v a, b, root, root_twisted + mulmod tmp, \b, \root, \root_twisted // tmp = b * root, reduced by mulmod (per-lane roots) + sub \b\().8h, \a\().8h, tmp.8h // b = a - tmp + add \a\().8h, \a\().8h, tmp.8h // a = a + tmp +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 // load 16 bytes, then advance the pointer by 32 + ldr q_root1, [r01234_ptr, #-16] // load the second 16 bytes of the 32 just stepped over +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) // post-increment by 96 bytes; the loads below index back into that 96-byte group + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s // 4x4 transpose of 32-bit lanes across the four data registers, using t0-t3 as scratch + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro save_vregs + sub sp, sp, #(16*4) // reserve 64 bytes of stack for d8-d15 + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + 
ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + // Arguments + in .req x0 // Input/output buffer + r01234_ptr .req x1 // twiddles for layer 0,1,2,3,4 + r56_ptr .req x2 // twiddles for layer 5,6 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + consts .req v7 + q_consts .req q7 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + .text + .global MLKEM_ASM_NAMESPACE(ntt_asm_clean) + +/* Literal pool */ +.p2align 4 +c_consts: + .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + +MLKEM_ASM_NAMESPACE(ntt_asm_clean): + push_stack + ldr q_consts, c_consts + + mov inp, in + mov count, #4 + + load_roots_012 + + .p2align 2 + + // Bounds reasoning: + // - There are 7 layers + // - When passing from layer N to layer N+1, each layer-N value + // is modified through the addition/subtraction of a Montgomery + // product of a twiddle of absolute value < q/2 and a layer-N value. + // - Recalling that for C such that |a| < C * q and |t| + // 0 25 + // |------------------------|---- + ldr q21, [x0, #0] // *............................. + ldr q26, [x0, #64] // ..*........................... + ldr q29, [x0, #128] // ....*......................... + ldr q20, [x0, #192] // ......*....................... + ldr q23, [x0, #256] // ........*..................... 
+ ldr q11, [x0, #448] // ..........*................... + mul v2.8H, v23.8H, v0.H[0] // ............*................. + ldr q17, [x0, #320] // .............*................ + mul v15.8H, v11.8H, v0.H[0] // ...............*.............. + ldr q13, [x0, #384] // ................*............. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q21, [x0, #0] // *.............................. + // ldr q26, [x0, #64] // ..*............................ + // ldr q29, [x0, #128] // ....*.......................... + // ldr q20, [x0, #192] // ......*........................ + // ldr q23, [x0, #256] // ........*...................... + // ldr q17, [x0, #320] // .............*................. + // mul v2.8H, v23.8H, v0.H[0] // ............*.................. + // ldr q11, [x0, #448] // ..........*.................... + // ldr q13, [x0, #384] // ................*.............. + // mul v15.8H, v11.8H, v0.H[0] // ...............*............... + + sub count, count, #1 +1: + // Instructions: 76 + // Expected cycles: 84 + // Expected IPC: 0.90 + // + // Cycle bound: 84.0 + // IPC bound: 0.90 + // + // Wall time: 2.36s + // User time: 2.36s + // + // -------------------------------- cycle (expected) ---------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|-------- + sqrdmulh v14.8H, v23.8H, v0.H[1] // *................................................................................... + sqrdmulh v23.8H, v17.8H, v0.H[1] // .*.................................................................................. + mul v17.8H, v17.8H, v0.H[0] // ..*................................................................................. + sqrdmulh v28.8H, v13.8H, v0.H[1] // ...*................................................................................ + mls v2.8H, v14.8H, v7.H[0] // ....*............................................................................... 
+ mul v14.8H, v13.8H, v0.H[0] // .....*.............................................................................. + mls v17.8H, v23.8H, v7.H[0] // ......*............................................................................. + sqrdmulh v23.8H, v11.8H, v0.H[1] // .......*............................................................................ + sub v11.8H, v21.8H, v2.8H // ........*........................................................................... + mls v14.8H, v28.8H, v7.H[0] // .........*.......................................................................... + sub v28.8H, v26.8H, v17.8H // ..........*......................................................................... + add v17.8H, v26.8H, v17.8H // ...........*........................................................................ + add v2.8H, v21.8H, v2.8H // ............*....................................................................... + sub v13.8H, v29.8H, v14.8H // .............*...................................................................... + add v14.8H, v29.8H, v14.8H // ..............*..................................................................... + mls v15.8H, v23.8H, v7.H[0] // ...............*.................................................................... + sqrdmulh v23.8H, v13.8H, v0.H[5] // ................*................................................................... + mul v13.8H, v13.8H, v0.H[4] // .................*.................................................................. + sqrdmulh v21.8H, v14.8H, v0.H[3] // ..................*................................................................. + sub v26.8H, v20.8H, v15.8H // ...................*................................................................ + add v15.8H, v20.8H, v15.8H // ....................*............................................................... 
+ mls v13.8H, v23.8H, v7.H[0] // .....................*.............................................................. + sqrdmulh v23.8H, v26.8H, v0.H[5] // ......................*............................................................. + mul v26.8H, v26.8H, v0.H[4] // .......................*............................................................ + mul v14.8H, v14.8H, v0.H[2] // ........................*........................................................... + sub v29.8H, v11.8H, v13.8H // .........................*.......................................................... + add v11.8H, v11.8H, v13.8H // ..........................*......................................................... + mls v26.8H, v23.8H, v7.H[0] // ...........................*........................................................ + sqrdmulh v23.8H, v15.8H, v0.H[3] // ............................*....................................................... + mul v13.8H, v15.8H, v0.H[2] // .............................*...................................................... + mls v14.8H, v21.8H, v7.H[0] // ..............................*..................................................... + sub v15.8H, v28.8H, v26.8H // ...............................*.................................................... + add v28.8H, v28.8H, v26.8H // ................................*................................................... + mls v13.8H, v23.8H, v7.H[0] // .................................*.................................................. + sub v23.8H, v2.8H, v14.8H // ..................................*................................................. + add v14.8H, v2.8H, v14.8H // ...................................*................................................ + sqrdmulh v2.8H, v28.8H, v1.H[3] // ....................................*............................................... 
+ sub v21.8H, v17.8H, v13.8H // .....................................*.............................................. + add v17.8H, v17.8H, v13.8H // ......................................*............................................. + mul v28.8H, v28.8H, v1.H[2] // .......................................*............................................ + sqrdmulh v13.8H, v21.8H, v1.H[1] // ........................................*........................................... + sqrdmulh v26.8H, v17.8H, v0.H[7] // .........................................*.......................................... + mul v17.8H, v17.8H, v0.H[6] // ..........................................*......................................... + mul v21.8H, v21.8H, v1.H[0] // ...........................................*........................................ + mls v28.8H, v2.8H, v7.H[0] // ............................................*....................................... + sqrdmulh v2.8H, v15.8H, v1.H[5] // .............................................*...................................... + mls v17.8H, v26.8H, v7.H[0] // ..............................................*..................................... + mls v21.8H, v13.8H, v7.H[0] // ...............................................*.................................... + sub v13.8H, v11.8H, v28.8H // ................................................*................................... + add v28.8H, v11.8H, v28.8H // .................................................*.................................. + sub v11.8H, v14.8H, v17.8H // ..................................................*................................. + mul v15.8H, v15.8H, v1.H[4] // ...................................................*................................ + add v14.8H, v14.8H, v17.8H // ....................................................*............................... 
+ sub v17.8H, v23.8H, v21.8H // .....................................................*.............................. + add v23.8H, v23.8H, v21.8H // ......................................................*............................. + mls v15.8H, v2.8H, v7.H[0] // .......................................................*............................ + str q14, [x0], #(16) // ........................................................*........................... + ldr q21, [x0, #0] // .........................................................e.......................... + sub v14.8H, v29.8H, v15.8H // ...........................................................*........................ + add v2.8H, v29.8H, v15.8H // ............................................................*....................... + str q11, [x0, #48] // .............................................................*...................... + ldr q26, [x0, #64] // ..............................................................e..................... + str q23, [x0, #112] // ................................................................*................... + ldr q29, [x0, #128] // .................................................................e.................. + str q17, [x0, #176] // ...................................................................*................ + ldr q20, [x0, #192] // ....................................................................e............... + str q28, [x0, #240] // ......................................................................*............. + ldr q23, [x0, #256] // .......................................................................e............ + str q13, [x0, #304] // .........................................................................*.......... + ldr q17, [x0, #320] // ..........................................................................e......... 
+ str q2, [x0, #368] // ............................................................................*....... + mul v2.8H, v23.8H, v0.H[0] // .............................................................................e...... + str q14, [x0, #432] // ..............................................................................*..... + ldr q11, [x0, #448] // ...............................................................................e.... + ldr q13, [x0, #384] // .................................................................................e.. + mul v15.8H, v11.8H, v0.H[0] // ...................................................................................e + + // ------------------------------------------- cycle (expected) --------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|----- + // ldr q8, [x0, #0] // e..........................'........................................................~..................... + // ldr q9, [x0, #(1*(512/8))] // .....e.....................'.............................................................~................ + // ldr q10, [x0, #(2*(512/8))] // ........e..................'................................................................~............. + // ldr q11, [x0, #(3*(512/8))] // ...........e...............'...................................................................~.......... + // ldr q12, [x0, #(4*(512/8))] // ..............e............'......................................................................~....... + // ldr q13, [x0, #(5*(512/8))] // .................e.........'.........................................................................~.... + // ldr q14, [x0, #(6*(512/8))] // ........................e..'.............................................................................. 
+ // ldr q15, [x0, #(7*(512/8))] // ......................e....'.............................................................................. + // sqrdmulh v27.8h, v12.8h, v0.h[1] // ...........................*.............................................................................. + // mul v24.8h, v12.8h, v0.h[0] // ....................e......'............................................................................~. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'...*.......................................................................... + // sub v12.8h, v8.8h, v24.8h // ...........................'.......*...................................................................... + // add v8.8h, v8.8h, v24.8h // ...........................'...........*.................................................................. + // sqrdmulh v27.8h, v13.8h, v0.h[1] // ...........................'*............................................................................. + // mul v24.8h, v13.8h, v0.h[0] // ...........................'.*............................................................................ + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.....*........................................................................ + // sub v13.8h, v9.8h, v24.8h // ...........................'.........*.................................................................... + // add v9.8h, v9.8h, v24.8h // ...........................'..........*................................................................... + // sqrdmulh v27.8h, v14.8h, v0.h[1] // ...........................'..*........................................................................... + // mul v24.8h, v14.8h, v0.h[0] // ...........................'....*......................................................................... 
+ // mls v24.8h, v27.8h, v7.h[0] // ...........................'........*..................................................................... + // sub v14.8h, v10.8h, v24.8h // ...........................'............*................................................................. + // add v10.8h, v10.8h, v24.8h // ...........................'.............*................................................................ + // sqrdmulh v27.8h, v15.8h, v0.h[1] // ...........................'......*....................................................................... + // mul v24.8h, v15.8h, v0.h[0] // ..........................e'.............................................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..............*............................................................... + // sub v15.8h, v11.8h, v24.8h // ...........................'..................*........................................................... + // add v11.8h, v11.8h, v24.8h // ...........................'...................*.......................................................... + // sqrdmulh v27.8h, v10.8h, v0.h[3] // ...........................'.................*............................................................ + // mul v24.8h, v10.8h, v0.h[2] // ...........................'.......................*...................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.............................*................................................ + // sub v10.8h, v8.8h, v24.8h // ...........................'.................................*............................................ + // add v8.8h, v8.8h, v24.8h // ...........................'..................................*........................................... 
+ // sqrdmulh v27.8h, v11.8h, v0.h[3] // ...........................'...........................*.................................................. + // mul v24.8h, v11.8h, v0.h[2] // ...........................'............................*................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'................................*............................................. + // sub v11.8h, v9.8h, v24.8h // ...........................'....................................*......................................... + // add v9.8h, v9.8h, v24.8h // ...........................'.....................................*........................................ + // sqrdmulh v27.8h, v14.8h, v0.h[5] // ...........................'...............*.............................................................. + // mul v24.8h, v14.8h, v0.h[4] // ...........................'................*............................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'....................*......................................................... + // sub v14.8h, v12.8h, v24.8h // ...........................'........................*..................................................... + // add v12.8h, v12.8h, v24.8h // ...........................'.........................*.................................................... + // sqrdmulh v27.8h, v15.8h, v0.h[5] // ...........................'.....................*........................................................ + // mul v24.8h, v15.8h, v0.h[4] // ...........................'......................*....................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..........................*................................................... 
+ // sub v15.8h, v13.8h, v24.8h // ...........................'..............................*............................................... + // add v13.8h, v13.8h, v24.8h // ...........................'...............................*.............................................. + // sqrdmulh v27.8h, v9.8h, v0.h[7] // ...........................'........................................*..................................... + // mul v24.8h, v9.8h, v0.h[6] // ...........................'.........................................*.................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.............................................*................................ + // sub v9.8h, v8.8h, v24.8h // ...........................'.................................................*............................ + // add v8.8h, v8.8h, v24.8h // ...........................'...................................................*.......................... + // sqrdmulh v27.8h, v11.8h, v1.h[1] // ...........................'.......................................*...................................... + // mul v24.8h, v11.8h, v1.h[0] // ...........................'..........................................*................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..............................................*............................... + // sub v11.8h, v10.8h, v24.8h // ...........................'....................................................*......................... + // add v10.8h, v10.8h, v24.8h // ...........................'.....................................................*........................ + // sqrdmulh v27.8h, v13.8h, v1.h[3] // ...........................'...................................*.......................................... 
+ // mul v24.8h, v13.8h, v1.h[2] // ...........................'......................................*....................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'...........................................*.................................. + // sub v13.8h, v12.8h, v24.8h // ...........................'...............................................*.............................. + // add v12.8h, v12.8h, v24.8h // ...........................'................................................*............................. + // sqrdmulh v27.8h, v15.8h, v1.h[5] // ...........................'............................................*................................. + // mul v24.8h, v15.8h, v1.h[4] // ...........................'..................................................*........................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'......................................................*....................... + // sub v15.8h, v14.8h, v24.8h // ..~........................'..........................................................*................... + // add v14.8h, v14.8h, v24.8h // ...~.......................'...........................................................*.................. + // str q8, [x0], #(16) // ...........................'.......................................................*...................... + // str q9, [x0, #(-16 + 1*(512/8))] // ....~......................'............................................................*................. + // str q10, [x0, #(-16 + 2*(512/8))] // .......~...................'...............................................................*.............. + // str q11, [x0, #(-16 + 3*(512/8))] // ..........~................'..................................................................*........... 
+ // str q12, [x0, #(-16 + 4*(512/8))] // .............~.............'.....................................................................*........ + // str q13, [x0, #(-16 + 5*(512/8))] // ................~..........'........................................................................*..... + // str q14, [x0, #(-16 + 6*(512/8))] // ...................~.......'...........................................................................*.. + // str q15, [x0, #(-16 + 7*(512/8))] // .....................~.....'.............................................................................* + + sub count, count, 1 + cbnz count, 1b + // Instructions: 66 + // Expected cycles: 67 + // Expected IPC: 0.99 + // + // Cycle bound: 67.0 + // IPC bound: 0.99 + // + // Wall time: 7.51s + // User time: 7.51s + // + // ------------------------ cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|---------------- + sqrdmulh v27.8H, v11.8H, v0.H[1] // *.................................................................. + mul v8.8H, v13.8H, v0.H[0] // .*................................................................. + sqrdmulh v22.8H, v13.8H, v0.H[1] // ..*................................................................ + mul v11.8H, v17.8H, v0.H[0] // ...*............................................................... + mls v15.8H, v27.8H, v7.H[0] // ....*.............................................................. + sqrdmulh v28.8H, v17.8H, v0.H[1] // .....*............................................................. + mls v8.8H, v22.8H, v7.H[0] // ......*............................................................ + sqrdmulh v5.8H, v23.8H, v0.H[1] // .......*........................................................... + add v16.8H, v20.8H, v15.8H // ........*.......................................................... + mls v11.8H, v28.8H, v7.H[0] // .........*......................................................... 
+ sub v6.8H, v29.8H, v8.8H // ..........*........................................................ + sqrdmulh v17.8H, v16.8H, v0.H[3] // ...........*....................................................... + mul v23.8H, v16.8H, v0.H[2] // ............*...................................................... + mul v13.8H, v6.8H, v0.H[4] // .............*..................................................... + sqrdmulh v28.8H, v6.8H, v0.H[5] // ..............*.................................................... + mls v2.8H, v5.8H, v7.H[0] // ...............*................................................... + mls v23.8H, v17.8H, v7.H[0] // ................*.................................................. + add v27.8H, v26.8H, v11.8H // .................*................................................. + mls v13.8H, v28.8H, v7.H[0] // ..................*................................................ + sub v9.8H, v21.8H, v2.8H // ...................*............................................... + add v18.8H, v29.8H, v8.8H // ....................*.............................................. + sub v14.8H, v27.8H, v23.8H // .....................*............................................. + add v29.8H, v9.8H, v13.8H // ......................*............................................ + sub v30.8H, v9.8H, v13.8H // .......................*........................................... + mul v28.8H, v14.8H, v1.H[0] // ........................*.......................................... + sqrdmulh v9.8H, v18.8H, v0.H[3] // .........................*......................................... + mul v22.8H, v18.8H, v0.H[2] // ..........................*........................................ + sqrdmulh v17.8H, v14.8H, v1.H[1] // ...........................*....................................... + sub v14.8H, v20.8H, v15.8H // ............................*...................................... 
+ add v24.8H, v21.8H, v2.8H // .............................*..................................... + mls v22.8H, v9.8H, v7.H[0] // ..............................*.................................... + sqrdmulh v9.8H, v14.8H, v0.H[5] // ...............................*................................... + mul v13.8H, v14.8H, v0.H[4] // ................................*.................................. + mls v28.8H, v17.8H, v7.H[0] // .................................*................................. + sub v5.8H, v24.8H, v22.8H // ..................................*................................ + sub v2.8H, v26.8H, v11.8H // ...................................*............................... + mls v13.8H, v9.8H, v7.H[0] // ....................................*.............................. + sub v17.8H, v5.8H, v28.8H // .....................................*............................. + add v14.8H, v5.8H, v28.8H // ......................................*............................ + add v28.8H, v27.8H, v23.8H // .......................................*........................... + str q17, [x0, #192] // ........................................*.......................... + add v17.8H, v2.8H, v13.8H // .........................................*......................... + str q14, [x0, #128] // ..........................................*........................ + sub v13.8H, v2.8H, v13.8H // ...........................................*....................... + sqrdmulh v26.8H, v17.8H, v1.H[3] // ............................................*...................... + mul v15.8H, v17.8H, v1.H[2] // .............................................*..................... + add v5.8H, v24.8H, v22.8H // ..............................................*.................... + sqrdmulh v23.8H, v13.8H, v1.H[5] // ...............................................*................... + mul v13.8H, v13.8H, v1.H[4] // ................................................*.................. 
+ mls v15.8H, v26.8H, v7.H[0] // .................................................*................. + sqrdmulh v14.8H, v28.8H, v0.H[7] // ..................................................*................ + mul v17.8H, v28.8H, v0.H[6] // ...................................................*............... + mls v13.8H, v23.8H, v7.H[0] // ....................................................*.............. + add v6.8H, v29.8H, v15.8H // .....................................................*............. + sub v28.8H, v29.8H, v15.8H // ......................................................*............ + mls v17.8H, v14.8H, v7.H[0] // .......................................................*........... + str q6, [x0, #256] // ........................................................*.......... + add v14.8H, v30.8H, v13.8H // .........................................................*......... + str q28, [x0, #320] // ..........................................................*........ + sub v23.8H, v30.8H, v13.8H // ...........................................................*....... + str q14, [x0, #384] // ............................................................*...... + add v3.8H, v5.8H, v17.8H // .............................................................*..... + str q23, [x0, #448] // ..............................................................*.... + sub v28.8H, v5.8H, v17.8H // ...............................................................*... + str q3, [x0], #(16) // ................................................................*.. + str q28, [x0, #48] // ..................................................................* + + // ------------------------ cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|---------------- + // sqrdmulh v14.8H, v23.8H, v0.H[1] // .......*........................................................... 
+ // sqrdmulh v23.8H, v17.8H, v0.H[1] // .....*............................................................. + // mul v17.8H, v17.8H, v0.H[0] // ...*............................................................... + // sqrdmulh v28.8H, v13.8H, v0.H[1] // ..*................................................................ + // mls v2.8H, v14.8H, v7.H[0] // ...............*................................................... + // mul v14.8H, v13.8H, v0.H[0] // .*................................................................. + // mls v17.8H, v23.8H, v7.H[0] // .........*......................................................... + // sqrdmulh v23.8H, v11.8H, v0.H[1] // *.................................................................. + // sub v11.8H, v21.8H, v2.8H // ...................*............................................... + // mls v14.8H, v28.8H, v7.H[0] // ......*............................................................ + // sub v28.8H, v26.8H, v17.8H // ...................................*............................... + // add v17.8H, v26.8H, v17.8H // .................*................................................. + // add v2.8H, v21.8H, v2.8H // .............................*..................................... + // sub v13.8H, v29.8H, v14.8H // ..........*........................................................ + // add v14.8H, v29.8H, v14.8H // ....................*.............................................. + // mls v15.8H, v23.8H, v7.H[0] // ....*.............................................................. + // sqrdmulh v23.8H, v13.8H, v0.H[5] // ..............*.................................................... + // mul v13.8H, v13.8H, v0.H[4] // .............*..................................................... + // sqrdmulh v21.8H, v14.8H, v0.H[3] // .........................*......................................... + // sub v26.8H, v20.8H, v15.8H // ............................*...................................... 
+ // add v15.8H, v20.8H, v15.8H // ........*.......................................................... + // mls v13.8H, v23.8H, v7.H[0] // ..................*................................................ + // sqrdmulh v23.8H, v26.8H, v0.H[5] // ...............................*................................... + // mul v26.8H, v26.8H, v0.H[4] // ................................*.................................. + // mul v14.8H, v14.8H, v0.H[2] // ..........................*........................................ + // sub v29.8H, v11.8H, v13.8H // .......................*........................................... + // add v11.8H, v11.8H, v13.8H // ......................*............................................ + // mls v26.8H, v23.8H, v7.H[0] // ....................................*.............................. + // sqrdmulh v23.8H, v15.8H, v0.H[3] // ...........*....................................................... + // mul v13.8H, v15.8H, v0.H[2] // ............*...................................................... + // mls v14.8H, v21.8H, v7.H[0] // ..............................*.................................... + // sub v15.8H, v28.8H, v26.8H // ...........................................*....................... + // add v28.8H, v28.8H, v26.8H // .........................................*......................... + // mls v13.8H, v23.8H, v7.H[0] // ................*.................................................. + // sub v23.8H, v2.8H, v14.8H // ..................................*................................ + // add v14.8H, v2.8H, v14.8H // ..............................................*.................... + // sqrdmulh v2.8H, v28.8H, v1.H[3] // ............................................*...................... + // sub v21.8H, v17.8H, v13.8H // .....................*............................................. + // add v17.8H, v17.8H, v13.8H // .......................................*........................... 
+ // mul v28.8H, v28.8H, v1.H[2] // .............................................*..................... + // sqrdmulh v13.8H, v21.8H, v1.H[1] // ...........................*....................................... + // sqrdmulh v26.8H, v17.8H, v0.H[7] // ..................................................*................ + // mul v17.8H, v17.8H, v0.H[6] // ...................................................*............... + // mul v21.8H, v21.8H, v1.H[0] // ........................*.......................................... + // mls v28.8H, v2.8H, v7.H[0] // .................................................*................. + // sqrdmulh v2.8H, v15.8H, v1.H[5] // ...............................................*................... + // mls v17.8H, v26.8H, v7.H[0] // .......................................................*........... + // mls v21.8H, v13.8H, v7.H[0] // .................................*................................. + // sub v13.8H, v11.8H, v28.8H // ......................................................*............ + // add v28.8H, v11.8H, v28.8H // .....................................................*............. + // sub v11.8H, v14.8H, v17.8H // ...............................................................*... + // mul v15.8H, v15.8H, v1.H[4] // ................................................*.................. + // add v14.8H, v14.8H, v17.8H // .............................................................*..... + // sub v17.8H, v23.8H, v21.8H // .....................................*............................. + // add v23.8H, v23.8H, v21.8H // ......................................*............................ + // mls v15.8H, v2.8H, v7.H[0] // ....................................................*.............. + // str q14, [x0], #(16) // ................................................................*.. + // sub v14.8H, v29.8H, v15.8H // ...........................................................*....... 
+ // add v2.8H, v29.8H, v15.8H // .........................................................*......... + // str q11, [x0, #48] // ..................................................................* + // str q23, [x0, #112] // ..........................................*........................ + // str q17, [x0, #176] // ........................................*.......................... + // str q28, [x0, #240] // ........................................................*.......... + // str q13, [x0, #304] // ..........................................................*........ + // str q2, [x0, #368] // ............................................................*...... + // str q14, [x0, #432] // ..............................................................*.... + + + mov in, inp + mov count, #8 + + .p2align 2 + // Instructions: 24 + // Expected cycles: 31 + // Expected IPC: 0.77 + // + // Cycle bound: 31.0 + // IPC bound: 0.77 + // + // Wall time: 0.08s + // User time: 0.08s + // + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + ldr q2, [x1], #16 // *.............................. + ldr q14, [x0, #48] // ..*............................ + ldr q1, [x0, #32] // ....*.......................... + mul v17.8H, v14.8H, v2.H[0] // ......*........................ + sqrdmulh v14.8H, v14.8H, v2.H[1] // .......*....................... + mul v8.8H, v1.8H, v2.H[0] // ........*...................... + ldr q23, [x0, #16] // .........*..................... + mls v17.8H, v14.8H, v7.H[0] // ...........*................... + sqrdmulh v1.8H, v1.8H, v2.H[1] // ............*.................. + ldr q30, [x2], #(6*16) // .............*................. + sub v14.8H, v23.8H, v17.8H // ...............*............... + add v10.8H, v23.8H, v17.8H // ................*.............. + mls v8.8H, v1.8H, v7.H[0] // .................*............. + sqrdmulh v1.8H, v14.8H, v2.H[5] // ..................*............ 
+ mul v14.8H, v14.8H, v2.H[4] // ...................*........... + ldr q27, [x0, #0] // ....................*.......... + mul v23.8H, v10.8H, v2.H[2] // ......................*........ + mls v14.8H, v1.8H, v7.H[0] // .......................*....... + sub v1.8H, v27.8H, v8.8H // ........................*...... + ldr q28, [x2, #-64] // .........................*..... + add v12.8H, v1.8H, v14.8H // ...........................*... + sqrdmulh v21.8H, v10.8H, v2.H[3] // ............................*.. + sub v5.8H, v1.8H, v14.8H // .............................*. + ldr q13, [x2, #-16] // ..............................* + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q19, [x0, #48] // ..*............................ + // ldr q1, [x1], #16 // *.............................. + // mul v4.8H, v19.8H, v1.H[0] // ......*........................ + // sqrdmulh v19.8H, v19.8H, v1.H[1] // .......*....................... + // ldr q25, [x0, #16] // .........*..................... + // mls v4.8H, v19.8H, v7.H[0] // ...........*................... + // sub v24.8H, v25.8H, v4.8H // ...............*............... + // add v4.8H, v25.8H, v4.8H // ................*.............. + // sqrdmulh v23.8H, v24.8H, v1.H[5] // ..................*............ + // mul v20.8H, v24.8H, v1.H[4] // ...................*........... + // sqrdmulh v21.8H, v4.8H, v1.H[3] // ............................*.. + // mls v20.8H, v23.8H, v7.H[0] // .......................*....... + // mul v23.8H, v4.8H, v1.H[2] // ......................*........ + // ldr q31, [x0, #32] // ....*.......................... + // mul v8.8H, v31.8H, v1.H[0] // ........*...................... + // sqrdmulh v1.8H, v31.8H, v1.H[1] // ............*.................. + // mls v8.8H, v1.8H, v7.H[0] // .................*............. + // ldr q27, [x0, #0] // ....................*.......... + // sub v10.8H, v27.8H, v8.8H // ........................*...... 
+ // add v12.8H, v10.8H, v20.8H // ...........................*... + // ldr q30, [x2], #(6*16) // .............*................. + // ldr q28, [x2, #-64] // .........................*..... + // sub v5.8H, v10.8H, v20.8H // .............................*. + // ldr q13, [x2, #-16] // ..............................* + + sub count, count, #1 +1: + // Instructions: 71 + // Expected cycles: 82 + // Expected IPC: 0.87 + // + // Cycle bound: 82.0 + // IPC bound: 0.87 + // + // Wall time: 11.93s + // User time: 11.93s + // + // ------------------------------- cycle (expected) --------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------ + ldr q19, [x0, #112] // e................................................................................. + ldr q1, [x1], #16 // ..e............................................................................... + mls v23.8H, v21.8H, v7.H[0] // ....*............................................................................. + add v6.8H, v27.8H, v8.8H // .....*............................................................................ + mul v4.8H, v19.8H, v1.H[0] // ......e........................................................................... + sqrdmulh v19.8H, v19.8H, v1.H[1] // .......e.......................................................................... + ldr q25, [x0, #80] // ........e......................................................................... + trn1 v11.4S, v12.4S, v5.4S // ..........*....................................................................... + mls v4.8H, v19.8H, v7.H[0] // ...........e...................................................................... + sub v0.8H, v6.8H, v23.8H // ............*..................................................................... + ldr q16, [x2, #-80] // .............*.................................................................... 
+ sub v24.8H, v25.8H, v4.8H // ...............e.................................................................. + add v26.8H, v6.8H, v23.8H // ................*................................................................. + add v4.8H, v25.8H, v4.8H // .................e................................................................ + sqrdmulh v23.8H, v24.8H, v1.H[5] // ..................e............................................................... + mul v20.8H, v24.8H, v1.H[4] // ...................e.............................................................. + sqrdmulh v21.8H, v4.8H, v1.H[3] // ....................e............................................................. + trn1 v27.4S, v26.4S, v0.4S // .....................*............................................................ + trn2 v25.4S, v12.4S, v5.4S // ......................*........................................................... + mls v20.8H, v23.8H, v7.H[0] // .......................e.......................................................... + mul v23.8H, v4.8H, v1.H[2] // ........................e......................................................... + ldr q31, [x0, #96] // .........................e........................................................ + trn2 v12.4S, v26.4S, v0.4S // ...........................*...................................................... + trn2 v19.2D, v27.2D, v11.2D // ............................*..................................................... + mul v8.8H, v31.8H, v1.H[0] // .............................e.................................................... + sqrdmulh v1.8H, v31.8H, v1.H[1] // ..............................e................................................... + trn2 v10.2D, v12.2D, v25.2D // ...............................*.................................................. + sqrdmulh v0.8H, v19.8H, v16.8H // ................................*................................................. 
+ sqrdmulh v18.8H, v10.8H, v16.8H // .................................*................................................ + trn1 v16.2D, v27.2D, v11.2D // ..................................*............................................... + trn1 v2.2D, v12.2D, v25.2D // ...................................*.............................................. + mul v12.8H, v10.8H, v30.8H // ....................................*............................................. + mul v10.8H, v19.8H, v30.8H // .....................................*............................................ + mls v8.8H, v1.8H, v7.H[0] // ......................................e........................................... + ldr q14, [x2, #-48] // .......................................*.......................................... + mls v10.8H, v0.8H, v7.H[0] // .........................................*........................................ + mls v12.8H, v18.8H, v7.H[0] // ..........................................*....................................... + ldr q27, [x0, #64] // ...........................................e...................................... + add v9.8H, v16.8H, v10.8H // .............................................*.................................... + sub v16.8H, v16.8H, v10.8H // ..............................................*................................... + sub v25.8H, v2.8H, v12.8H // ...............................................*.................................. + add v30.8H, v2.8H, v12.8H // ................................................*................................. + sub v10.8H, v27.8H, v8.8H // .................................................e................................ + sqrdmulh v22.8H, v25.8H, v13.8H // ..................................................*............................... + sqrdmulh v13.8H, v30.8H, v14.8H // ...................................................*.............................. 
+ ldr q14, [x2, #-32] // ....................................................*............................. + add v12.8H, v10.8H, v20.8H // ......................................................e........................... + mul v5.8H, v30.8H, v28.8H // .......................................................*.......................... + mul v26.8H, v25.8H, v14.8H // ........................................................*......................... + ldr q30, [x2], #(6*16) // .........................................................e........................ + mls v5.8H, v13.8H, v7.H[0] // ...........................................................*...................... + mls v26.8H, v22.8H, v7.H[0] // ............................................................*..................... + ldr q28, [x2, #-64] // .............................................................e.................... + add v13.8H, v9.8H, v5.8H // ...............................................................*.................. + sub v9.8H, v9.8H, v5.8H // ................................................................*................. + sub v5.8H, v16.8H, v26.8H // .................................................................*................ + add v25.8H, v16.8H, v26.8H // ..................................................................*............... + trn1 v15.4S, v13.4S, v9.4S // ...................................................................*.............. + trn2 v3.4S, v13.4S, v9.4S // ....................................................................*............. + trn1 v13.4S, v25.4S, v5.4S // .....................................................................*............ + trn2 v31.4S, v25.4S, v5.4S // ......................................................................*........... + sub v5.8H, v10.8H, v20.8H // .......................................................................e.......... 
+ trn1 v2.2D, v15.2D, v13.2D // ........................................................................*......... + trn2 v9.2D, v15.2D, v13.2D // .........................................................................*........ + str q2, [x0], #(16*4) // ..........................................................................*....... + trn1 v29.2D, v3.2D, v31.2D // ...........................................................................*...... + str q9, [x0, #-32] // ............................................................................*..... + trn2 v9.2D, v3.2D, v31.2D // .............................................................................*.... + str q29, [x0, #-48] // ..............................................................................*... + ldr q13, [x2, #-16] // ...............................................................................e.. + str q9, [x0, #-16] // .................................................................................* + + // ------------------------------------------------------------------------ cycle (expected) -------------------------------------------------------------------------> + // 0 25 50 75 100 125 150 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------- + // ldr q8, [x0, #(16*0)] // ...........................................e......................................'..........................................~...................................... + // ldr q9, [x0, #(16*1)] // ........e.........................................................................'.......~......................................................................... + // ldr q10, [x0, #(16*2)] // .........................e........................................................'........................~........................................................ 
+ // ldr q11, [x0, #(16*3)] // e.................................................................................~................................................................................. + // ldr q0, [x1], #16 // ..e...............................................................................'.~............................................................................... + // sqrdmulh v27.8h, v10.8h, v0.h[1] // ..............................e...................................................'.............................~................................................... + // mul v24.8h, v10.8h, v0.h[0] // .............................e....................................................'............................~.................................................... + // mls v24.8h, v27.8h, v7.h[0] // ......................................e...........................................'.....................................~........................................... + // sub v10.8h, v8.8h, v24.8h // .................................................e................................'................................................~................................ + // add v8.8h, v8.8h, v24.8h // .....~............................................................................'....*............................................................................ + // sqrdmulh v27.8h, v11.8h, v0.h[1] // .......e..........................................................................'......~.......................................................................... + // mul v24.8h, v11.8h, v0.h[0] // ......e...........................................................................'.....~........................................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........e......................................................................'..........~...................................................................... 
+ // sub v11.8h, v9.8h, v24.8h // ...............e..................................................................'..............~.................................................................. + // add v9.8h, v9.8h, v24.8h // .................e................................................................'................~................................................................ + // sqrdmulh v27.8h, v9.8h, v0.h[3] // ....................e.............................................................'...................~............................................................. + // mul v24.8h, v9.8h, v0.h[2] // ........................e.........................................................'.......................~......................................................... + // mls v24.8h, v27.8h, v7.h[0] // ....~.............................................................................'...*............................................................................. + // sub v9.8h, v8.8h, v24.8h // ............~.....................................................................'...........*..................................................................... + // add v8.8h, v8.8h, v24.8h // ................~.................................................................'...............*................................................................. + // sqrdmulh v27.8h, v11.8h, v0.h[5] // ..................e...............................................................'.................~............................................................... + // mul v24.8h, v11.8h, v0.h[4] // ...................e..............................................................'..................~.............................................................. 
+ // mls v24.8h, v27.8h, v7.h[0] // .......................e..........................................................'......................~.......................................................... + // sub v11.8h, v10.8h, v24.8h // .......................................................................e..........'......................................................................~.......... + // add v10.8h, v10.8h, v24.8h // ......................................................e...........................'.....................................................~........................... + // trn1 v25.4s, v8.4s, v9.4s // .....................~............................................................'....................*............................................................ + // trn2 v26.4s, v8.4s, v9.4s // ...........................~......................................................'..........................*...................................................... + // trn1 v27.4s, v10.4s, v11.4s // ..........~.......................................................................'.........*....................................................................... + // trn2 v28.4s, v10.4s, v11.4s // ......................~...........................................................'.....................*........................................................... + // trn2 v10.2d, v25.2d, v27.2d // ............................~.....................................................'...........................*..................................................... + // trn2 v11.2d, v26.2d, v28.2d // ...............................~..................................................'..............................*.................................................. 
+ // trn1 v8.2d, v25.2d, v27.2d // ..................................~...............................................'.................................*............................................... + // trn1 v9.2d, v26.2d, v28.2d // ...................................~..............................................'..................................*.............................................. + // ldr q0, [x2], #(6*16) // .........................................................e........................'........................................................~........................ + // ldr q4, [x2, #(-6*16 + 1*16)] // .............~....................................................................'............*.................................................................... + // ldr q1, [x2, #(-6*16 + 2*16)] // .............................................................e....................'............................................................~.................... + // ldr q5, [x2, #(-6*16 + 3*16)] // .......................................~..........................................'......................................*.......................................... + // ldr q2, [x2, #(-6*16 + 4*16)] // ....................................................~.............................'...................................................*............................. + // ldr q6, [x2, #(-6*16 + 5*16)] // ...............................................................................e..'..............................................................................~.. + // sqrdmulh v27.8h, v10.8h, v4.8h // ................................~.................................................'...............................*................................................. 
+ // mul v24.8h, v10.8h, v0.8h // .....................................~............................................'....................................*............................................ + // mls v24.8h, v27.8h, v7.h[0] // .........................................~........................................'........................................*........................................ + // sub v10.8h, v8.8h, v24.8h // ..............................................~...................................'.............................................*................................... + // add v8.8h, v8.8h, v24.8h // .............................................~....................................'............................................*.................................... + // sqrdmulh v27.8h, v11.8h, v4.8h // .................................~................................................'................................*................................................ + // mul v24.8h, v11.8h, v0.8h // ....................................~.............................................'...................................*............................................. + // mls v24.8h, v27.8h, v7.h[0] // ..........................................~.......................................'.........................................*....................................... + // sub v11.8h, v9.8h, v24.8h // ...............................................~..................................'..............................................*.................................. + // add v9.8h, v9.8h, v24.8h // ................................................~.................................'...............................................*................................. 
+ // sqrdmulh v27.8h, v9.8h, v5.8h // ...................................................~..............................'..................................................*.............................. + // mul v24.8h, v9.8h, v1.8h // .......................................................~..........................'......................................................*.......................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................................................~......................'..........................................................*...................... + // sub v9.8h, v8.8h, v24.8h // ................................................................~.................'...............................................................*................. + // add v8.8h, v8.8h, v24.8h // ...............................................................~..................'..............................................................*.................. + // sqrdmulh v27.8h, v11.8h, v6.8h // ..................................................~...............................'.................................................*............................... + // mul v24.8h, v11.8h, v2.8h // ........................................................~.........................'.......................................................*......................... + // mls v24.8h, v27.8h, v7.h[0] // ............................................................~.....................'...........................................................*..................... + // sub v11.8h, v10.8h, v24.8h // .................................................................~................'................................................................*................ 
+ // add v10.8h, v10.8h, v24.8h // ..................................................................~...............'.................................................................*............... + // trn1 v25.4s, v8.4s, v9.4s // ...................................................................~..............'..................................................................*.............. + // trn2 v26.4s, v8.4s, v9.4s // ....................................................................~.............'...................................................................*............. + // trn1 v27.4s, v10.4s, v11.4s // .....................................................................~............'....................................................................*............ + // trn2 v28.4s, v10.4s, v11.4s // ......................................................................~...........'.....................................................................*........... + // trn2 v10.2d, v25.2d, v27.2d // .........................................................................~........'........................................................................*........ + // trn2 v11.2d, v26.2d, v28.2d // .............................................................................~....'............................................................................*.... + // trn1 v8.2d, v25.2d, v27.2d // ........................................................................~.........'.......................................................................*......... + // trn1 v9.2d, v26.2d, v28.2d // ...........................................................................~......'..........................................................................*...... + // str q8, [x0], #(16*4) // ..........................................................................~.......'.........................................................................*....... 
+ // str q9, [x0, #(-16*3)] // ..............................................................................~...'.............................................................................*... + // str q10, [x0, #(-16*2)] // ............................................................................~.....'...........................................................................*..... + // str q11, [x0, #(-16*1)] // .................................................................................~'................................................................................* + + sub count, count, 1 + cbnz count, 1b + // Instructions: 47 + // Expected cycles: 52 + // Expected IPC: 0.90 + // + // Cycle bound: 52.0 + // IPC bound: 0.90 + // + // Wall time: 5.32s + // User time: 5.32s + // + // ---------------- cycle (expected) -----------------> + // 0 25 50 + // |------------------------|------------------------|- + mls v23.8H, v21.8H, v7.H[0] // *................................................... + add v14.8H, v27.8H, v8.8H // .*.................................................. + ldr q1, [x2, #-32] // ..*................................................. + add v17.8H, v14.8H, v23.8H // ....*............................................... + sub v23.8H, v14.8H, v23.8H // .....*.............................................. + trn2 v11.4S, v12.4S, v5.4S // ......*............................................. + trn1 v27.4S, v12.4S, v5.4S // .......*............................................ + trn2 v2.4S, v17.4S, v23.4S // ........*........................................... + ldr q26, [x2, #-80] // .........*.......................................... + trn2 v14.2D, v2.2D, v11.2D // ...........*........................................ + trn1 v15.4S, v17.4S, v23.4S // ............*....................................... + mul v5.8H, v14.8H, v30.8H // .............*...................................... 
+ sqrdmulh v23.8H, v14.8H, v26.8H // ..............*..................................... + trn2 v17.2D, v15.2D, v27.2D // ...............*.................................... + trn1 v14.2D, v2.2D, v11.2D // ................*................................... + mul v21.8H, v17.8H, v30.8H // .................*.................................. + mls v5.8H, v23.8H, v7.H[0] // ..................*................................. + sqrdmulh v17.8H, v17.8H, v26.8H // ...................*................................ + ldr q2, [x2, #-48] // ....................*............................... + sub v23.8H, v14.8H, v5.8H // ......................*............................. + add v14.8H, v14.8H, v5.8H // .......................*............................ + mls v21.8H, v17.8H, v7.H[0] // ........................*........................... + mul v1.8H, v23.8H, v1.8H // .........................*.......................... + sqrdmulh v17.8H, v23.8H, v13.8H // ..........................*......................... + mul v23.8H, v14.8H, v28.8H // ...........................*........................ + sqrdmulh v14.8H, v14.8H, v2.8H // ............................*....................... + trn1 v28.2D, v15.2D, v27.2D // .............................*...................... + mls v1.8H, v17.8H, v7.H[0] // ..............................*..................... + sub v11.8H, v28.8H, v21.8H // ...............................*.................... + mls v23.8H, v14.8H, v7.H[0] // ................................*................... + add v17.8H, v28.8H, v21.8H // .................................*.................. + sub v14.8H, v11.8H, v1.8H // ..................................*................. + add v1.8H, v11.8H, v1.8H // ...................................*................ + sub v28.8H, v17.8H, v23.8H // ....................................*............... + add v2.8H, v17.8H, v23.8H // .....................................*.............. 
+ trn1 v23.4S, v1.4S, v14.4S // ......................................*............. + trn2 v14.4S, v1.4S, v14.4S // .......................................*............ + trn2 v17.4S, v2.4S, v28.4S // ........................................*........... + trn1 v28.4S, v2.4S, v28.4S // .........................................*.......... + trn2 v1.2D, v17.2D, v14.2D // ...........................................*........ + trn1 v14.2D, v17.2D, v14.2D // ............................................*....... + str q1, [x0, #48] // .............................................*...... + trn2 v1.2D, v28.2D, v23.2D // ..............................................*..... + str q14, [x0, #16] // ...............................................*.... + trn1 v14.2D, v28.2D, v23.2D // ................................................*... + str q1, [x0, #32] // .................................................*.. + str q14, [x0], #(16*4) // ...................................................* + + // ---------------- cycle (expected) -----------------> + // 0 25 50 + // |------------------------|------------------------|- + // mls v23.8H, v21.8H, v7.H[0] // *................................................... + // add v6.8H, v27.8H, v8.8H // .*.................................................. + // trn1 v11.4S, v12.4S, v5.4S // .......*............................................ + // sub v0.8H, v6.8H, v23.8H // .....*.............................................. + // ldr q16, [x2, #-80] // .........*.......................................... + // add v26.8H, v6.8H, v23.8H // ....*............................................... + // trn1 v27.4S, v26.4S, v0.4S // ............*....................................... + // trn2 v25.4S, v12.4S, v5.4S // ......*............................................. + // trn2 v12.4S, v26.4S, v0.4S // ........*........................................... + // trn2 v19.2D, v27.2D, v11.2D // ...............*.................................... 
+ // trn2 v10.2D, v12.2D, v25.2D // ...........*........................................ + // sqrdmulh v0.8H, v19.8H, v16.8H // ...................*................................ + // sqrdmulh v18.8H, v10.8H, v16.8H // ..............*..................................... + // trn1 v16.2D, v27.2D, v11.2D // .............................*...................... + // trn1 v2.2D, v12.2D, v25.2D // ................*................................... + // mul v12.8H, v10.8H, v30.8H // .............*...................................... + // mul v10.8H, v19.8H, v30.8H // .................*.................................. + // ldr q14, [x2, #-48] // ....................*............................... + // mls v10.8H, v0.8H, v7.H[0] // ........................*........................... + // mls v12.8H, v18.8H, v7.H[0] // ..................*................................. + // add v9.8H, v16.8H, v10.8H // .................................*.................. + // sub v16.8H, v16.8H, v10.8H // ...............................*.................... + // sub v25.8H, v2.8H, v12.8H // ......................*............................. + // add v30.8H, v2.8H, v12.8H // .......................*............................ + // sqrdmulh v22.8H, v25.8H, v13.8H // ..........................*......................... + // sqrdmulh v13.8H, v30.8H, v14.8H // ............................*....................... + // ldr q14, [x2, #-32] // ..*................................................. + // mul v5.8H, v30.8H, v28.8H // ...........................*........................ + // mul v26.8H, v25.8H, v14.8H // .........................*.......................... + // mls v5.8H, v13.8H, v7.H[0] // ................................*................... + // mls v26.8H, v22.8H, v7.H[0] // ..............................*..................... + // add v13.8H, v9.8H, v5.8H // .....................................*.............. 
+ // sub v9.8H, v9.8H, v5.8H // ....................................*............... + // sub v5.8H, v16.8H, v26.8H // ..................................*................. + // add v25.8H, v16.8H, v26.8H // ...................................*................ + // trn1 v15.4S, v13.4S, v9.4S // .........................................*.......... + // trn2 v3.4S, v13.4S, v9.4S // ........................................*........... + // trn1 v13.4S, v25.4S, v5.4S // ......................................*............. + // trn2 v31.4S, v25.4S, v5.4S // .......................................*............ + // trn1 v2.2D, v15.2D, v13.2D // ................................................*... + // trn2 v9.2D, v15.2D, v13.2D // ..............................................*..... + // str q2, [x0], #(16*4) // ...................................................* + // trn1 v29.2D, v3.2D, v31.2D // ............................................*....... + // str q9, [x0, #-32] // .................................................*.. + // trn2 v9.2D, v3.2D, v31.2D // ...........................................*........ + // str q29, [x0, #-48] // ...............................................*.... + // str q9, [x0, #-16] // .............................................*...... + + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/opt_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/opt_impl.h new file mode 100644 index 0000000000..b226740261 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/opt_impl.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for optimized (opt) assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles?
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include "arith_native_aarch64.h" + +#include "poly.h" +#include "polyvec.h" + +/* Set of primitives that this backend replaces; each MLKEM_USE_NATIVE_* flag below is paired with one *_native wrapper defined later in this header */ +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_REJ_UNIFORM + +#define NTT_BOUND_NATIVE (6 * MLKEM_Q) +static INLINE void ntt_native(poly *data) +{ + ntt_asm_opt(data->coeffs, aarch64_ntt_zetas_layer01234, + aarch64_ntt_zetas_layer56); +} + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +static INLINE void intt_native(poly *data) +{ + intt_asm_opt(data->coeffs, aarch64_invntt_zetas_layer01234, + aarch64_invntt_zetas_layer56); +} + +static INLINE void poly_reduce_native(poly *data) +{ + poly_reduce_asm_opt(data->coeffs); +} +static INLINE void poly_tomont_native(poly *data) +{ + poly_tomont_asm_opt(data->coeffs); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + poly_mulcache_compute_asm_opt(x->coeffs, y->coeffs, + aarch64_zetas_mulcache_native, + aarch64_zetas_mulcache_twisted_native); +} +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_asm_opt( + r->coeffs, a->vec[0].coeffs, b->vec[0].coeffs, b_cache->vec[0].coeffs); /* only the first polynomial of each vector is passed; NOTE(review): presumably the routine walks the remaining vec[] entries itself -- confirm */ +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + poly_tobytes_asm_opt(r, a->coeffs); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + if (len != MLKEM_N || buflen % 24 != 0) /* the assembly is only called for len == MLKEM_N and buflen a multiple of 24; len is checked here but not forwarded */ + { + return -1; /* NOTE(review): presumably tells the caller to use a non-native path -- confirm */ + } + return (int)rej_uniform_asm_clean(r, buf, buflen, rej_uniform_table); /* the _clean routine is used even in this opt profile; NOTE(review): the patch file list shows rej_uniform_asm_clean.S but no _opt counterpart -- confirm this is intentional */ +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ 
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/optimize.sh b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/optimize.sh new file mode 100755 index 0000000000..9d43dfa80d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/optimize.sh @@ -0,0 +1,121 @@ +#!/usr/bin/env sh +# Copyright (c) 2024 The mlkem-native project authors +# SPDX-License-Identifier: Apache-2.0 + +set -e + +TARGET_NAME="Cortex-A55" +TARGET=Arm_Cortex_A55 + +echo "* polyvec_basemul_acc_montgomery_cached, K=2, ${TARGET_NAME}" + +cp polyvec_clean.S polyvec_opt.S + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k2_clean,polyvec_basemul_acc_montgomery_cached_asm_k2_opt \ + -l k2_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c sw_pipelining.allow_post \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* polyvec_basemul_acc_montgomery_cached, K=3, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k3_clean,polyvec_basemul_acc_montgomery_cached_asm_k3_opt \ + -l k3_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c sw_pipelining.allow_post \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* polyvec_basemul_acc_montgomery_cached, K=4, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k4_clean,polyvec_basemul_acc_montgomery_cached_asm_k4_opt \ + -l k4_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c 
sw_pipelining.allow_post \ + -c constraints.stalls_first_attempt=64 + +cp poly_clean.S poly_opt.S + +echo "* poly_reduce, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_reduce_asm_clean,poly_reduce_asm_opt \ + -l loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* poly_mulcache_compute, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_mulcache_compute_asm_clean,poly_mulcache_compute_asm_opt \ + -l mulcache_compute_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* poly_tomont, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_tomont_asm_clean,poly_tomont_asm_opt \ + -l poly_tomont_asm_loop \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo " * ntt, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + ntt_clean.S -o ntt_opt.S \ + -r ntt_asm_clean,ntt_asm_opt \ + -l layer123_start \ + -l layer4567_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo " * intt, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + intt_clean.S -o intt_opt.S \ + -r intt_asm_clean,intt_asm_opt \ + -l layer123_start \ + -l layer4567_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c 
sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_clean.S new file mode 100644 index 0000000000..f70a402215 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_clean.S @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27 +c_mont_constant: dup8h -1044 // 2^16 % 3329 (signed representative: 2285 - 3329) +c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16) + +/* + * Some modular arithmetic macros + */ + +/* Barrett reduction of the argument in place; reads modulus_twisted.h[0] and modulus.h[0], clobbers tmp */ +.macro barrett_reduce a + sqdmulh tmp.8h, \a\().8h, modulus_twisted.h[0] + srshr tmp.8h, tmp.8h, #11 + mls \a\().8h, tmp.8h, modulus.h[0] +.endm + +/* Montgomery multiplication, with precomputed Montgomery twist + * Expects modulus in modulus.h[0]; clobbers tmp0. */ +.macro mulmod dst, src, const, const_twisted + sqrdmulh tmp0.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, tmp0.8h, modulus.h[0] +.endm + +/* Turns signed-canonical to unsigned canonical representative + * through conditional addition of the modulus.
+ * + * Expects the modulus in `modulus`; clobbers `mask`. */ +.macro scalar_signed_to_unsigned a + sshr mask.8h, \a\().8h, #15 + and mask.16b, modulus.16b, mask.16b + add \a\().8h, \a\().8h, mask.8h +.endm + +/********************************** + * poly_reduce() * + **********************************/ + +.global MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean) + + ptr .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + + tmp .req v1 + mask .req v2 + modulus .req v3 + q_modulus .req q3 + modulus_twisted .req v4 + q_modulus_twisted .req q4 + +MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #8 // 8 iterations, each handling 64 bytes (four q-registers) in place +loop_start: + ldr q_data, [ptr], #64 + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-64] + + ldr q_data, [ptr, #-48] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-48] + + ldr q_data, [ptr, #-32] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-32] + + ldr q_data, [ptr, #-16] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-16] + + subs count, count, #1 // NOTE(review): cbnz tests the register itself, so the flags set by subs are not consumed here + cbnz count, loop_start + + ret + + .unreq ptr + .unreq count + + .unreq data + .unreq q_data + + .unreq tmp + .unreq mask + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_mulcache_compute() * + ********************************************/ + +.global MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean) + + cache_ptr .req x0 + data_ptr .req x1 + zeta_ptr .req x2 + zeta_twisted_ptr .req x3 + count .req x4 + + data_odd .req v0 + zeta .req v1 + q_zeta .req q1 + zeta_twisted .req v2 + q_zeta_twisted .req q2 + + tmp0 .req v3 + q_tmp0 .req q3 + tmp1 .req v4 + q_tmp1 .req q4 + dst .req v5 + q_dst .req q5 + + modulus .req v6 + q_modulus .req q6 + modulus_twisted .req v7 + q_modulus_twisted .req q7 + +MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean): + ldr
q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted // NOTE(review): modulus_twisted is not referenced again in this function (mulmod only uses modulus) + + mov count, #16 // 16 iterations; each reads 32 input bytes and writes 16 cache bytes +mulcache_compute_loop_start: + ldr q_tmp0, [data_ptr], #32 + ldr q_tmp1, [data_ptr, #-16] + ldr q_zeta, [zeta_ptr], #16 + ldr q_zeta_twisted, [zeta_twisted_ptr], #16 + + // The mulcache of a polynomial a + b*X in Fq[X^2-zeta] is b*zeta; + // Since tmp0 || tmp1 represents multiple such polynomials as + // (a0,b0,a1,b1,...), extract only the odd elements. + uzp2 data_odd.8h, tmp0.8h, tmp1.8h + mulmod dst, data_odd, zeta, zeta_twisted + + str q_dst, [cache_ptr], #16 + + subs count, count, #1 // flags from subs are unused (cbnz tests the register) + cbnz count, mulcache_compute_loop_start + + ret + + .unreq cache_ptr + .unreq data_ptr + .unreq zeta_ptr + .unreq zeta_twisted_ptr + .unreq count + + .unreq data_odd + .unreq zeta + .unreq q_zeta + .unreq zeta_twisted + .unreq q_zeta_twisted + + .unreq tmp0 + .unreq q_tmp0 + .unreq tmp1 + .unreq q_tmp1 + .unreq dst + .unreq q_dst + + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_tobytes() * + ********************************************/ +.global MLKEM_ASM_NAMESPACE(poly_tobytes_asm_clean) + + data0 .req v0 + data1 .req v1 + out0 .req v2 + out1 .req v3 + out2 .req v4 + tmp .req v5 + + dst .req x0 + src .req x1 + count .req x2 + +MLKEM_ASM_NAMESPACE(poly_tobytes_asm_clean): + + mov count, #16 // 16 iterations; each consumes 32 input bytes and emits 24 output bytes +poly_tobytes_asm_clean_asm_loop_start: + ld2 {data0.8h, data1.8h}, [src], #32 + + // r[3 * i + 0] = (t0 >> 0); + xtn out0.8b, data0.8h + + // r[3 * i + 1] = (t0 >> 8); + shrn out1.8b, data0.8h, #8 + xtn tmp.8b, data1.8h + // r[3 * i + 1] = (t0 >> 8) | (t1 << 4); + sli out1.8b, tmp.8b, #4 + + // r[3 * i + 2] = (t1 >> 4); + shrn out2.8b, data1.8h, #4 + + st3 {out0.8b, out1.8b, out2.8b}, [dst], #24 + + subs count, count, #1 // flags from subs are unused (cbnz tests the register) + cbnz count, poly_tobytes_asm_clean_asm_loop_start + ret + + .unreq data0 + .unreq data1 + .unreq out0 + .unreq out1 + .unreq out2 + .unreq tmp + .unreq dst + .unreq src + .unreq count + 
+/********************************** + * poly_tomont() * + **********************************/ +.global MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean) + + src .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + res .req v1 + q_res .req q1 + + factor .req v2 + q_factor .req q2 + factor_t .req v3 + q_factor_t .req q3 + modulus .req v4 + q_modulus .req q4 + modulus_twisted .req v5 + q_modulus_twisted .req q5 + + tmp0 .req v6 + +MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted // NOTE(review): modulus_twisted is not referenced again in this function (mulmod only uses modulus) + ldr q_factor, c_mont_constant + ldr q_factor_t, c_barrett_twist + + mov count, #8 // 8 iterations, each handling 64 bytes (four q-registers) in place +poly_tomont_asm_loop: + + ldr q_data, [src], #64 + mulmod res, data, factor, factor_t + str q_res, [src, #-64] + + ldr q_data, [src, #-48] + mulmod res, data, factor, factor_t + str q_res, [src, #-48] + + ldr q_data, [src, #-32] + mulmod res, data, factor, factor_t + str q_res, [src, #-32] + + ldr q_data, [src, #-16] + mulmod res, data, factor, factor_t + str q_res, [src, #-16] + + sub count, count, #1 + cbnz count, poly_tomont_asm_loop + + ret + + .unreq src + .unreq count + + .unreq data + .unreq q_data + .unreq res + .unreq q_res + + .unreq factor + .unreq q_factor + .unreq factor_t + .unreq q_factor_t + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + + .unreq tmp0 + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_opt.S new file mode 100644 index 0000000000..e58ee77c46 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/poly_opt.S @@ -0,0 +1,690 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +/* We use a single literal pool for all functions in this file. 
+ * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27 +c_mont_constant: dup8h -1044 // 2^16 % 3329 (signed representative: 2285 - 3329) +c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16) + +/* + * Some modular arithmetic macros + */ + +/* Barrett reduction of the argument in place; reads modulus_twisted.h[0] and modulus.h[0], clobbers tmp */ +.macro barrett_reduce a + sqdmulh tmp.8h, \a\().8h, modulus_twisted.h[0] + srshr tmp.8h, tmp.8h, #11 + mls \a\().8h, tmp.8h, modulus.h[0] +.endm + +/* Montgomery multiplication, with precomputed Montgomery twist + * Expects modulus in modulus.h[0]; clobbers tmp0. */ +.macro mulmod dst, src, const, const_twisted + sqrdmulh tmp0.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, tmp0.8h, modulus.h[0] +.endm + +/* Turns signed-canonical to unsigned canonical representative + * through conditional addition of the modulus. + * + * Expected modulus in `modulus`. 
*/ +.macro scalar_signed_to_unsigned a + sshr mask.8h, \a\().8h, #15 + and mask.16b, modulus.16b, mask.16b + add \a\().8h, \a\().8h, mask.8h +.endm + +/********************************** + * poly_reduce() * + **********************************/ + +.global MLKEM_ASM_NAMESPACE(poly_reduce_asm_opt) + + ptr .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + + tmp .req v1 + mask .req v2 + modulus .req v3 + q_modulus .req q3 + modulus_twisted .req v4 + q_modulus_twisted .req q4 + +MLKEM_ASM_NAMESPACE(poly_reduce_asm_opt): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #8 + // Instructions: 15 + // Expected cycles: 22 + // Expected IPC: 0.68 + + // Cycle bound: 22.0 + // IPC bound: 0.68 + + // Wall time: 0.05s + // User time: 0.05s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q21, [x0, #32] // *............................. + ldr q23, [x0, #48] // ..*........................... + sqdmulh v7.8H, v21.8H, v4.H[0] // ....*......................... + sqdmulh v30.8H, v23.8H, v4.H[0] // ......*....................... + srshr v7.8H, v7.8H, #11 // ........*..................... + srshr v30.8H, v30.8H, #11 // ..........*................... + mls v21.8H, v7.8H, v3.H[0] // ...........*.................. + mls v23.8H, v30.8H, v3.H[0] // .............*................ + ldr q5, [x0, #16] // ..............*............... + sshr v7.8H, v21.8H, #15 // ................*............. + sshr v30.8H, v23.8H, #15 // .................*............ + and v7.16B, v3.16B, v7.16B // ..................*........... + add v21.8H, v21.8H, v7.8H // ...................*.......... + and v7.16B, v3.16B, v30.16B // ....................*......... + add v16.8H, v23.8H, v7.8H // .....................*........ + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q30, [x0, #32] // *.............................. 
+ // sqdmulh v22.8H, v30.8H, v4.H[0] // ....*.......................... + // ldr q2, [x0, #48] // ..*............................ + // srshr v19.8H, v22.8H, #11 // ........*...................... + // mls v30.8H, v19.8H, v3.H[0] // ...........*................... + // sqdmulh v25.8H, v2.8H, v4.H[0] // ......*........................ + // sshr v31.8H, v30.8H, #15 // ................*.............. + // srshr v25.8H, v25.8H, #11 // ..........*.................... + // and v18.16B, v3.16B, v31.16B // ..................*............ + // mls v2.8H, v25.8H, v3.H[0] // .............*................. + // add v21.8H, v30.8H, v18.8H // ...................*........... + // ldr q5, [x0, #16] // ..............*................ + // sshr v18.8H, v2.8H, #15 // .................*............. + // and v27.16B, v3.16B, v18.16B // ....................*.......... + // add v16.8H, v2.8H, v27.8H // .....................*......... + + sub count, count, #1 +1: + // Instructions: 32 + // Expected cycles: 36 + // Expected IPC: 0.89 + + // Cycle bound: 36.0 + // IPC bound: 0.89 + + // Wall time: 1.05s + // User time: 1.05s + + // -------- cycle (expected) ---------> + // 0 25 + // |------------------------|---------- + ldr q6, [x0], #64 // *................................... + ldr q30, [x0, #32] // ..e................................. + sqdmulh v31.8H, v6.8H, v4.H[0] // ....*............................... + sqdmulh v29.8H, v5.8H, v4.H[0] // .....*.............................. + sqdmulh v22.8H, v30.8H, v4.H[0] // ......e............................. + str q16, [x0, #-16] // .......*............................ + srshr v20.8H, v31.8H, #11 // ........*........................... + srshr v28.8H, v29.8H, #11 // .........*.......................... + str q21, [x0, #-32] // ..........*......................... + mls v6.8H, v20.8H, v3.H[0] // ...........*........................ + mls v5.8H, v28.8H, v3.H[0] // ............*....................... 
+ ldr q2, [x0, #48] // .............e...................... + sshr v31.8H, v6.8H, #15 // ...............*.................... + srshr v19.8H, v22.8H, #11 // ................e................... + and v22.16B, v3.16B, v31.16B // .................*.................. + add v0.8H, v6.8H, v22.8H // ..................*................. + mls v30.8H, v19.8H, v3.H[0] // ...................e................ + sshr v26.8H, v5.8H, #15 // ....................*............... + sqdmulh v25.8H, v2.8H, v4.H[0] // .....................e.............. + and v17.16B, v3.16B, v26.16B // ......................*............. + add v1.8H, v5.8H, v17.8H // .......................*............ + sshr v31.8H, v30.8H, #15 // ........................e........... + srshr v25.8H, v25.8H, #11 // .........................e.......... + str q1, [x0, #-48] // ..........................*......... + and v18.16B, v3.16B, v31.16B // ...........................e........ + mls v2.8H, v25.8H, v3.H[0] // ............................e....... + add v21.8H, v30.8H, v18.8H // .............................e...... + ldr q5, [x0, #16] // ..............................e..... + sshr v18.8H, v2.8H, #15 // ................................e... + str q0, [x0, #-64] // .................................*.. + and v27.16B, v3.16B, v18.16B // ..................................e. + add v16.8H, v2.8H, v27.8H // ...................................e + + // ------------------------ cycle (expected) -------------------------> + // 0 25 50 + // |------------------------|------------------------|----------------- + // ldr q0, [x0], #64 // ..................................*................................. + // sqdmulh v1.8h, v0.8h, v4.h[0] // ..~...............................'...*............................. + // srshr v1.8h, v1.8h, #11 // ......~...........................'.......*......................... + // mls v0.8h, v1.8h, v3.h[0] // .........~........................'..........*...................... 
+ // sshr v2.8h, v0.8h, #15 // .............~....................'..............*.................. + // and v2.16b, v3.16b, v2.16b // ...............~..................'................*................ + // add v0.8h, v0.8h, v2.8h // ................~.................'.................*............... + // str q0, [x0, #-64] // ...............................~..'................................* + // ldr q0, [x0, #-48] // ............................e.....'.............................~... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ...~..............................'....*............................ + // srshr v1.8h, v1.8h, #11 // .......~..........................'........*........................ + // mls v0.8h, v1.8h, v3.h[0] // ..........~.......................'...........*..................... + // sshr v2.8h, v0.8h, #15 // ..................~...............'...................*............. + // and v2.16b, v3.16b, v2.16b // ....................~.............'.....................*........... + // add v0.8h, v0.8h, v2.8h // .....................~............'......................*.......... + // str q0, [x0, #-48] // ........................~.........'.........................*....... + // ldr q0, [x0, #-32] // e.................................'.~............................... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ....e.............................'.....~........................... + // srshr v1.8h, v1.8h, #11 // ..............e...................'...............~................. + // mls v0.8h, v1.8h, v3.h[0] // .................e................'..................~.............. + // sshr v2.8h, v0.8h, #15 // ......................e...........'.......................~......... + // and v2.16b, v3.16b, v2.16b // .........................e........'..........................~...... + // add v0.8h, v0.8h, v2.8h // ...........................e......'............................~.... 
+ // str q0, [x0, #-32] // ........~.........................'.........*....................... + // ldr q0, [x0, #-16] // ...........e......................'............~.................... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ...................e..............'....................~............ + // srshr v1.8h, v1.8h, #11 // .......................e..........'........................~........ + // mls v0.8h, v1.8h, v3.h[0] // ..........................e.......'...........................~..... + // sshr v2.8h, v0.8h, #15 // ..............................e...'...............................~. + // and v2.16b, v3.16b, v2.16b // ................................e.'................................. + // add v0.8h, v0.8h, v2.8h // .................................e'................................. + // str q0, [x0, #-16] // .....~............................'......*.......................... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 17 + // Expected cycles: 23 + // Expected IPC: 0.74 + + // Cycle bound: 23.0 + // IPC bound: 0.74 + + // Wall time: 0.05s + // User time: 0.05s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + sqdmulh v20.8H, v5.8H, v4.H[0] // *............................. + ldr q24, [x0], #64 // .*............................ + str q21, [x0, #-32] // ...*.......................... + srshr v20.8H, v20.8H, #11 // ....*......................... + sqdmulh v25.8H, v24.8H, v4.H[0] // .....*........................ + str q16, [x0, #-16] // ......*....................... + mls v5.8H, v20.8H, v3.H[0] // .......*...................... + srshr v20.8H, v25.8H, #11 // .........*.................... + sshr v2.8H, v5.8H, #15 // ...........*.................. + mls v24.8H, v20.8H, v3.H[0] // ............*................. + and v20.16B, v3.16B, v2.16B // .............*................ + add v31.8H, v5.8H, v20.8H // ..............*............... + sshr v20.8H, v24.8H, #15 // ................*............. 
+ str q31, [x0, #-48] // .................*............ + and v31.16B, v3.16B, v20.16B // ..................*........... + add v24.8H, v24.8H, v31.8H // ...................*.......... + str q24, [x0, #-64] // ......................*....... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q6, [x0], #64 // .*............................. + // sqdmulh v31.8H, v6.8H, v4.H[0] // .....*......................... + // sqdmulh v29.8H, v5.8H, v4.H[0] // *.............................. + // str q16, [x0, #-16] // ......*........................ + // srshr v20.8H, v31.8H, #11 // .........*..................... + // srshr v28.8H, v29.8H, #11 // ....*.......................... + // str q21, [x0, #-32] // ...*........................... + // mls v6.8H, v20.8H, v3.H[0] // ............*.................. + // mls v5.8H, v28.8H, v3.H[0] // .......*....................... + // sshr v31.8H, v6.8H, #15 // ................*.............. + // and v22.16B, v3.16B, v31.16B // ..................*............ + // add v0.8H, v6.8H, v22.8H // ...................*........... + // sshr v26.8H, v5.8H, #15 // ...........*................... + // and v17.16B, v3.16B, v26.16B // .............*................. + // add v1.8H, v5.8H, v17.8H // ..............*................ + // str q1, [x0, #-48] // .................*............. + // str q0, [x0, #-64] // ......................*........ 
+ + + ret + + .unreq ptr + .unreq count + + .unreq data + .unreq q_data + + .unreq tmp + .unreq mask + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_mulcache_compute() * + ********************************************/ + +.global MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_opt) + + cache_ptr .req x0 + data_ptr .req x1 + zeta_ptr .req x2 + zeta_twisted_ptr .req x3 + count .req x4 + + data_odd .req v0 + zeta .req v1 + q_zeta .req q1 + zeta_twisted .req v2 + q_zeta_twisted .req q2 + + tmp0 .req v3 + q_tmp0 .req q3 + tmp1 .req v4 + q_tmp1 .req q4 + dst .req v5 + q_dst .req q5 + + modulus .req v6 + q_modulus .req q6 + modulus_twisted .req v7 + q_modulus_twisted .req q7 + +MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_opt): + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #16 + // Instructions: 7 + // Expected cycles: 12 + // Expected IPC: 0.58 + + // Cycle bound: 12.0 + // IPC bound: 0.58 + + // Wall time: 0.01s + // User time: 0.01s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q1, [x1, #16] // *............................. + ldr q27, [x1], #32 // ..*........................... + ldr q23, [x2], #16 // ....*......................... + uzp2 v27.8H, v27.8H, v1.8H // ......*....................... + ldr q1, [x3], #16 // .......*...................... + mul v2.8H, v27.8H, v23.8H // .........*.................... + sqrdmulh v27.8H, v27.8H, v1.8H // ...........*.................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q29, [x1, #16] // *.............................. + // ldr q21, [x2], #16 // ....*.......................... + // ldr q27, [x1], #32 // ..*............................ + // ldr q7, [x3], #16 // .......*....................... + // uzp2 v28.8H, v27.8H, v29.8H // ......*........................ 
+ // mul v2.8H, v28.8H, v21.8H // .........*..................... + // sqrdmulh v27.8H, v28.8H, v7.8H // ...........*................... + + sub count, count, #1 +1: + // Instructions: 9 + // Expected cycles: 13 + // Expected IPC: 0.69 + + // Cycle bound: 13.0 + // IPC bound: 0.69 + + // Wall time: 0.09s + // User time: 0.09s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q29, [x1, #16] // e............................. + ldr q21, [x2], #16 // ..e........................... + mls v2.8H, v27.8H, v6.H[0] // ....*......................... + ldr q27, [x1], #32 // .....e........................ + ldr q7, [x3], #16 // .......e...................... + uzp2 v28.8H, v27.8H, v29.8H // .........e.................... + str q2, [x0], #16 // ..........*................... + mul v2.8H, v28.8H, v21.8H // ...........e.................. + sqrdmulh v27.8H, v28.8H, v7.8H // ............e................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q3, [x1], #32 // .....e.......'....~.......'.... + // ldr q4, [x1, #-16] // e............~............~.... + // ldr q1, [x2], #16 // ..e..........'.~..........'.~.. + // ldr q2, [x3], #16 // .......e.....'......~.....'.... + // uzp2 v0.8h, v3.8h, v4.8h // .........e...'........~...'.... + // sqrdmulh v3.8h, v0.8h, v2.8h // ............e'...........~'.... + // mul v5.8h, v0.8h, v1.8h // ...........e.'..........~.'.... + // mls v5.8h, v3.8h, v6.h[0] // ....~........'...*........'.... + // str q5, [x0], #16 // ..........~..'.........*..'.... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 2 + // Expected cycles: 5 + // Expected IPC: 0.40 + + // Cycle bound: 5.0 + // IPC bound: 0.40 + + // Wall time: 0.00s + // User time: 0.00s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v2.8H, v27.8H, v6.H[0] // *............................. + str q2, [x0], #16 // ....*......................... 
+ + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // mls v2.8H, v27.8H, v6.H[0] // *.............................. + // str q2, [x0], #16 // ....*.......................... + + + ret + + .unreq cache_ptr + .unreq data_ptr + .unreq zeta_ptr + .unreq zeta_twisted_ptr + .unreq count + + .unreq data_odd + .unreq zeta + .unreq q_zeta + .unreq zeta_twisted + .unreq q_zeta_twisted + + .unreq tmp0 + .unreq q_tmp0 + .unreq tmp1 + .unreq q_tmp1 + .unreq dst + .unreq q_dst + + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_tobytes() * + ********************************************/ +.global MLKEM_ASM_NAMESPACE(poly_tobytes_asm_opt) + + data0 .req v0 + data1 .req v1 + out0 .req v2 + out1 .req v3 + out2 .req v4 + tmp .req v5 + + dst .req x0 + src .req x1 + count .req x2 + +MLKEM_ASM_NAMESPACE(poly_tobytes_asm_opt): + + mov count, #16 +poly_tobytes_asm_opt_asm_loop_start: + ld2 {data0.8h, data1.8h}, [src], #32 + + // r[3 * i + 0] = (t0 >> 0); + xtn out0.8b, data0.8h + + // r[3 * i + 1] = (t0 >> 8); + shrn out1.8b, data0.8h, #8 + xtn tmp.8b, data1.8h + // r[3 * i + 1] = (t0 >> 8) | (t1 << 4); + sli out1.8b, tmp.8b, #4 + + // r[3 * i + 2] = (t1 >> 4); + shrn out2.8b, data1.8h, #4 + + st3 {out0.8b, out1.8b, out2.8b}, [dst], #24 + + subs count, count, #1 + cbnz count, poly_tobytes_asm_opt_asm_loop_start + ret + + .unreq data0 + .unreq data1 + .unreq out0 + .unreq out1 + .unreq out2 + .unreq tmp + .unreq dst + .unreq src + .unreq count + +/********************************** + * poly_tomont() * + **********************************/ +.global MLKEM_ASM_NAMESPACE(poly_tomont_asm_opt) + + src .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + res .req v1 + q_res .req q1 + + factor .req v2 + q_factor .req q2 + factor_t .req v3 + q_factor_t .req q3 + modulus .req v4 + q_modulus .req q4 + modulus_twisted .req v5 + q_modulus_twisted .req q5 + + tmp0 
.req v6 + +MLKEM_ASM_NAMESPACE(poly_tomont_asm_opt): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + ldr q_factor, c_mont_constant + ldr q_factor_t, c_barrett_twist + + mov count, #8 + // Instructions: 5 + // Expected cycles: 7 + // Expected IPC: 0.71 + // + // Cycle bound: 7.0 + // IPC bound: 0.71 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q26, [x0, #48] // *............................. + ldr q23, [x0, #16] // ..*........................... + mul v17.8H, v26.8H, v2.8H // ....*......................... + sqrdmulh v7.8H, v26.8H, v3.8H // .....*........................ + ldr q27, [x0, #32] // ......*....................... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q7, [x0, #48] // *.............................. + // ldr q23, [x0, #16] // ..*............................ + // mul v17.8H, v7.8H, v2.8H // ....*.......................... + // sqrdmulh v7.8H, v7.8H, v3.8H // .....*......................... + // ldr q27, [x0, #32] // ......*........................ + + sub count, count, #1 +1: + // Instructions: 20 + // Expected cycles: 24 + // Expected IPC: 0.83 + // + // Cycle bound: 24.0 + // IPC bound: 0.83 + // + // Wall time: 0.73s + // User time: 0.73s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v17.8H, v7.8H, v4.H[0] // *............................. + sqrdmulh v5.8H, v23.8H, v3.8H // .*............................ + ldr q7, [x0], #64 // ..*........................... + str q17, [x0, #-16] // ....*......................... + sqrdmulh v29.8H, v27.8H, v3.8H // .....*........................ + sqrdmulh v19.8H, v7.8H, v3.8H // ......*....................... + mul v25.8H, v23.8H, v2.8H // .......*...................... + mul v0.8H, v7.8H, v2.8H // ........*..................... + mul v26.8H, v27.8H, v2.8H // .........*.................... 
+ ldr q7, [x0, #48] // ..........e................... + mls v25.8H, v5.8H, v4.H[0] // ............*................. + ldr q23, [x0, #16] // .............e................ + mls v26.8H, v29.8H, v4.H[0] // ...............*.............. + mls v0.8H, v19.8H, v4.H[0] // ................*............. + str q25, [x0, #-48] // .................*............ + mul v17.8H, v7.8H, v2.8H // ..................e........... + sqrdmulh v7.8H, v7.8H, v3.8H // ...................e.......... + str q0, [x0, #-64] // ....................*......... + ldr q27, [x0, #32] // .....................e........ + str q26, [x0, #-32] // .......................*...... + + // --------- cycle (expected) ----------> + // 0 25 + // |------------------------|------------ + // ldr q0, [x0], #64 // ..............'.*..................... + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'.....*................. + // mul v1.8h, v0.8h, v2.8h // ..............'.......*............... + // mls v1.8h, v6.8h, v4.h[0] // ......~.......'...............*....... + // str q1, [x0, #-64] // ..........~...'...................*... + // ldr q0, [x0, #-48] // ...e..........'............~.......... + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'*...................... + // mul v1.8h, v0.8h, v2.8h // ..............'......*................ + // mls v1.8h, v6.8h, v4.h[0] // ..~...........'...........*........... + // str q1, [x0, #-48] // .......~......'................*...... + // ldr q0, [x0, #-32] // ...........e..'....................~.. + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'....*.................. + // mul v1.8h, v0.8h, v2.8h // ..............'........*.............. + // mls v1.8h, v6.8h, v4.h[0] // .....~........'..............*........ + // str q1, [x0, #-32] // .............~'......................* + // ldr q0, [x0, #-16] // e.............'.........~............. + // sqrdmulh v6.8h, v0.8h, v3.8h // .........e....'..................~.... 
+ // mul v1.8h, v0.8h, v2.8h // ........e.....'.................~..... + // mls v1.8h, v6.8h, v4.h[0] // ..............*....................... + // str q1, [x0, #-16] // ..............'...*................... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 15 + // Expected cycles: 18 + // Expected IPC: 0.83 + // + // Cycle bound: 18.0 + // IPC bound: 0.83 + // + // Wall time: 0.07s + // User time: 0.07s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v17.8H, v7.8H, v4.H[0] // *............................. + sqrdmulh v7.8H, v23.8H, v3.8H // .*............................ + mul v26.8H, v23.8H, v2.8H // ..*........................... + sqrdmulh v25.8H, v27.8H, v3.8H // ...*.......................... + ldr q23, [x0], #64 // ....*......................... + mul v27.8H, v27.8H, v2.8H // ......*....................... + mls v26.8H, v7.8H, v4.H[0] // .......*...................... + sqrdmulh v7.8H, v23.8H, v3.8H // ........*..................... + mul v23.8H, v23.8H, v2.8H // .........*.................... + str q17, [x0, #-16] // ..........*................... + mls v27.8H, v25.8H, v4.H[0] // ...........*.................. + str q26, [x0, #-48] // ............*................. + mls v23.8H, v7.8H, v4.H[0] // .............*................ + str q27, [x0, #-32] // ...............*.............. + str q23, [x0, #-64] // .................*............ + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // mls v17.8H, v7.8H, v4.H[0] // *.............................. + // sqrdmulh v5.8H, v23.8H, v3.8H // .*............................. + // ldr q7, [x0], #64 // ....*.......................... + // str q17, [x0, #-16] // ..........*.................... + // sqrdmulh v29.8H, v27.8H, v3.8H // ...*........................... + // sqrdmulh v19.8H, v7.8H, v3.8H // ........*...................... + // mul v25.8H, v23.8H, v2.8H // ..*............................ 
+ // mul v0.8H, v7.8H, v2.8H // .........*..................... + // mul v26.8H, v27.8H, v2.8H // ......*........................ + // mls v25.8H, v5.8H, v4.H[0] // .......*....................... + // mls v26.8H, v29.8H, v4.H[0] // ...........*................... + // mls v0.8H, v19.8H, v4.H[0] // .............*................. + // str q25, [x0, #-48] // ............*.................. + // str q0, [x0, #-64] // .................*............. + // str q26, [x0, #-32] // ...............*............... + + + ret + + .unreq src + .unreq count + + .unreq data + .unreq q_data + .unreq res + .unreq q_res + + .unreq factor + .unreq q_factor + .unreq factor_t + .unreq q_factor_t + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + + .unreq tmp0 + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_clean.S new file mode 100644 index 0000000000..99fb05de5d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_clean.S @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +// +// AArch64 re-implementation of the asymmetric base multiplication from: +// +// Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1 +// https://eprint.iacr.org/2021/986 +// https://github.com/neon-ntt/neon-ntt + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. 
+ */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 3327 + +// Input: +// - Vectors al, ah of 32-bit entries +// Output: +// - Montgomery reductions of al || ah, stored in x (al, ah and t0 are clobbered) +.macro montgomery_reduce_long x, a + uzp1 t0.8h, \a\()l.8h, \a\()h.8h + mul t0.8h, t0.8h, modulus_twisted.8h + smlal \a\()l.4s, t0.4h, modulus.4h + smlal2 \a\()h.4s, t0.8h, modulus.8h + uzp2 \x\().8h, \a\()l.8h, \a\()h.8h +.endm + +// Computes products (a0*b0 + a1*b1t, a0*b1 + a1*b0) in 32-bit. +// +// Bounds: +// - Assume |a| < 4096, +// - Result: < 2*4096*2^15 = 2^28 +.macro pmull d, a, b + smull \d\()0l.4s, \a\()0.4h, \b\()0.4h + smull2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smull \d\()1l.4s, \a\()0.4h, \b\()1.4h + smull2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro pmlal d, a, b + smlal \d\()0l.4s, \a\()0.4h, \b\()0.4h + smlal2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smlal \d\()1l.4s, \a\()0.4h, \b\()1.4h + smlal2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro ld2_wrap a, ptr + ldr q_tmp0, [\ptr\()], #32 + ldr q_tmp1, [\ptr\(), #-16] + uzp1 \a\()0.8h, tmp0.8h, tmp1.8h + uzp2 \a\()1.8h, tmp0.8h, tmp1.8h +.endm + +.macro st2_wrap a, ptr + zip1 tmp0.8h, \a\()0.8h, \a\()1.8h + zip2 tmp1.8h, \a\()0.8h, \a\()1.8h + str q_tmp0, [\ptr\()], #32 + str q_tmp1, [\ptr\(), #-16] +.endm + +.macro load_polys a, b, a_ptr, b_ptr, b_cache_ptr + ld2_wrap \a\(), \a_ptr + ld2_wrap \b\(), \b_ptr + ld1 {\b\()1t.8h}, [\b_cache_ptr], #16 +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, 
[sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + out .req x0 + a0_ptr .req x1 + b0_ptr .req x2 + b0_cache_ptr .req x3 + a1_ptr .req x4 + b1_ptr .req x5 + b1_cache_ptr .req x6 + a2_ptr .req x7 + b2_ptr .req x8 + b2_cache_ptr .req x9 + a3_ptr .req x10 + b3_ptr .req x11 + b3_cache_ptr .req x12 + count .req x13 + + modulus .req v0 + q_modulus .req q0 + modulus_twisted .req v2 + q_modulus_twisted .req q2 + + aa0 .req v3 + aa1 .req v4 + bb0 .req v5 + bb1 .req v6 + bb1t .req v7 + + res0l .req v8 + res1l .req v9 + res0h .req v10 + res1h .req v11 + + tmp0 .req v12 + tmp1 .req v13 + q_tmp0 .req q12 + q_tmp1 .req q13 + + out0 .req v26 + out1 .req v27 + + t0 .req v28 + +#if MLKEM_K == 2 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + + mov count, #(MLKEM_N / 16) +k2_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k2_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 2 */ + +#if MLKEM_K == 3 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // 
Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + + mov count, #(MLKEM_N / 16) +k3_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr, b2_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k3_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 3 */ + +#if MLKEM_K == 4 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + add a3_ptr, a0_ptr, #(3 * 512) + add b3_ptr, b0_ptr, #(3 * 512) + add b3_cache_ptr, b0_cache_ptr, #(3 * 512/2) + + // Bounds: + // + // Each pmull is bound by 2*4096*2^15=2^28, so the final value + // before Montgomery reduction is bound by 2^30. 
+ + mov count, #(MLKEM_N / 16) +k4_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr, b2_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a3_ptr, b3_ptr, b3_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k4_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 4 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_opt.S new file mode 100644 index 0000000000..16ed77c3fc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/polyvec_opt.S @@ -0,0 +1,1584 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// AArch64 re-implementation of the asymmetric base multiplication from: + +// Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1 +// https://eprint.iacr.org/2021/986 +// https://github.com/neon-ntt/neon-ntt + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. 
+ */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 3327 + +// Input: +// - Vectors al, ah of 32-bit entries +// Output: +// - Montgomery reductions of al || ah, stored in x (al, ah and t0 are clobbered) +.macro montgomery_reduce_long x, a + uzp1 t0.8h, \a\()l.8h, \a\()h.8h + mul t0.8h, t0.8h, modulus_twisted.8h + smlal \a\()l.4s, t0.4h, modulus.4h + smlal2 \a\()h.4s, t0.8h, modulus.8h + uzp2 \x\().8h, \a\()l.8h, \a\()h.8h +.endm + +// Computes products (a0*b0 + a1*b1t, a0*b1 + a1*b0) in 32-bit. + +// Bounds: +// - Assume |a| < 4096, +// - Result: < 2*4096*2^15 = 2^28 +.macro pmull d, a, b + smull \d\()0l.4s, \a\()0.4h, \b\()0.4h + smull2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smull \d\()1l.4s, \a\()0.4h, \b\()1.4h + smull2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro pmlal d, a, b + smlal \d\()0l.4s, \a\()0.4h, \b\()0.4h + smlal2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smlal \d\()1l.4s, \a\()0.4h, \b\()1.4h + smlal2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro ld2_wrap a, ptr + ldr q_tmp0, [\ptr\()], #32 + ldr q_tmp1, [\ptr\(), #-16] + uzp1 \a\()0.8h, tmp0.8h, tmp1.8h + uzp2 \a\()1.8h, tmp0.8h, tmp1.8h +.endm + +.macro st2_wrap a, ptr + zip1 tmp0.8h, \a\()0.8h, \a\()1.8h + zip2 tmp1.8h, \a\()0.8h, \a\()1.8h + str q_tmp0, [\ptr\()], #32 + str q_tmp1, [\ptr\(), #-16] +.endm + +.macro load_polys a, b, a_ptr, b_ptr, b_cache_ptr + ld2_wrap \a\(), \a_ptr + ld2_wrap \b\(), \b_ptr + ld1 {\b\()1t.8h}, [\b_cache_ptr], #16 +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, 
[sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + out .req x0 + a0_ptr .req x1 + b0_ptr .req x2 + b0_cache_ptr .req x3 + a1_ptr .req x4 + b1_ptr .req x5 + b1_cache_ptr .req x6 + a2_ptr .req x7 + b2_ptr .req x8 + b2_cache_ptr .req x9 + a3_ptr .req x10 + b3_ptr .req x11 + b3_cache_ptr .req x12 + count .req x13 + + modulus .req v0 + q_modulus .req q0 + modulus_twisted .req v2 + q_modulus_twisted .req q2 + + aa0 .req v3 + aa1 .req v4 + bb0 .req v5 + bb1 .req v6 + bb1t .req v7 + + res0l .req v8 + res1l .req v9 + res0h .req v10 + res1h .req v11 + + tmp0 .req v12 + tmp1 .req v13 + q_tmp0 .req q12 + q_tmp1 .req q13 + + out0 .req v26 + out1 .req v27 + + t0 .req v28 + +#if MLKEM_K == 2 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + + mov count, #(MLKEM_N / 16) + // Instructions: 75 + // Expected cycles: 94 + // Expected IPC: 0.80 + + // Cycle bound: 94.0 + // IPC bound: 0.80 + + // Wall time: 1.49s + // User time: 1.49s + + // --------------------------- original position ----------------------------> + // 0 25 50 + // |------------------------|------------------------| + ldr q9, [x4], #32 // *.......................................................................... + ldr q5, [x4, #-16] // ......*.................................................................... + ldr q11, [x5], #32 // .*......................................................................... 
+ uzp1 v23.8H, v9.8H, v5.8H // .........*................................................................. + uzp2 v9.8H, v9.8H, v5.8H // .....................*..................................................... + ldr q5, [x2], #32 // ..*........................................................................ + ldr q7, [x5, #-16] // ..............*............................................................ + ldr q21, [x2, #-16] // ...*....................................................................... + uzp2 v10.8H, v11.8H, v7.8H // .................*......................................................... + uzp1 v11.8H, v11.8H, v7.8H // ..................*........................................................ + uzp1 v7.8H, v5.8H, v21.8H // ....*...................................................................... + uzp2 v5.8H, v5.8H, v21.8H // .....*..................................................................... + ldr q21, [x1], #32 // .......*................................................................... + ldr q25, [x1, #-16] // ........*.................................................................. + ld1 {v6.8H}, [x3], #16 // ............................*.............................................. + uzp1 v26.8H, v21.8H, v25.8H // ..........*................................................................ + uzp2 v21.8H, v21.8H, v25.8H // ...........*............................................................... + smull v25.4S, v26.4H, v5.4H // ............*.............................................................. + smull2 v5.4S, v26.8H, v5.8H // .............*............................................................. + smull v19.4S, v26.4H, v7.4H // ..........................*................................................ + smull2 v26.4S, v26.8H, v7.8H // ..............................*............................................ 
+ smlal v25.4S, v21.4H, v7.4H // ...............*........................................................... + smlal2 v5.4S, v21.8H, v7.8H // ................*.......................................................... + smlal v19.4S, v21.4H, v6.4H // ...................................*....................................... + smlal2 v26.4S, v21.8H, v6.8H // .................................*......................................... + smlal v25.4S, v23.4H, v10.4H // ...................*....................................................... + smlal2 v5.4S, v23.8H, v10.8H // ....................*...................................................... + smlal v19.4S, v23.4H, v11.4H // ......................................*.................................... + smlal2 v26.4S, v23.8H, v11.8H // ....................................*...................................... + ld1 {v23.8H}, [x6], #16 // ........................*.................................................. + smlal v25.4S, v9.4H, v11.4H // ......................*.................................................... + smlal2 v5.4S, v9.8H, v11.8H // .......................*................................................... + smlal2 v26.4S, v9.8H, v23.8H // .......................................*................................... + smlal v19.4S, v9.4H, v23.4H // .........................................*................................. + ldr q9, [x4], #32 // ...............................*........................................... + uzp1 v11.8H, v25.8H, v5.8H // .........................*................................................. + uzp1 v23.8H, v19.8H, v26.8H // .............................................*............................. + mul v11.8H, v11.8H, v2.8H // ...........................*............................................... + mul v23.8H, v23.8H, v2.8H // ..............................................*............................ 
+ ldr q7, [x5], #32 // ................................*.......................................... + smlal2 v5.4S, v11.8H, v0.8H // .............................*............................................. + smlal v25.4S, v11.4H, v0.4H // ..................................*........................................ + ldr q11, [x2], #32 // .....................................*..................................... + ldr q21, [x2, #-16] // ........................................*.................................. + ldr q6, [x4, #-16] // ...............................................*........................... + uzp1 v17.8H, v11.8H, v21.8H // ...........................................*............................... + ldr q10, [x1], #32 // ................................................*.......................... + ldr q29, [x1, #-16] // .................................................*......................... + uzp2 v11.8H, v11.8H, v21.8H // ............................................*.............................. + uzp1 v13.8H, v9.8H, v6.8H // ...................................................*....................... + uzp1 v3.8H, v10.8H, v29.8H // ....................................................*...................... + uzp2 v10.8H, v10.8H, v29.8H // .....................................................*..................... + smull v12.4S, v3.4H, v11.4H // ......................................................*.................... + smull2 v11.4S, v3.8H, v11.8H // .......................................................*................... + ldr q21, [x5, #-16] // ........................................................*.................. + smlal v12.4S, v10.4H, v17.4H // .........................................................*................. + smlal2 v11.4S, v10.8H, v17.8H // ..........................................................*................ 
+ uzp2 v29.8H, v7.8H, v21.8H // ...........................................................*............... + uzp1 v15.8H, v7.8H, v21.8H // ............................................................*.............. + smlal v12.4S, v13.4H, v29.4H // .............................................................*............. + smlal2 v11.4S, v13.8H, v29.8H // ..............................................................*............ + uzp2 v28.8H, v9.8H, v6.8H // ...............................................................*........... + smlal2 v26.4S, v23.8H, v0.8H // ..................................................*........................ + smlal v12.4S, v28.4H, v15.4H // .................................................................*......... + smlal2 v11.4S, v28.8H, v15.8H // ..................................................................*........ + smlal v19.4S, v23.4H, v0.4H // ................................................................*.......... + uzp2 v27.8H, v25.8H, v5.8H // ..........................................*................................ + smull v23.4S, v3.4H, v17.4H // ......................................................................*.... + uzp1 v9.8H, v12.8H, v11.8H // .....................................................................*..... + uzp2 v19.8H, v19.8H, v26.8H // ....................................................................*...... + mul v14.8H, v9.8H, v2.8H // .......................................................................*... + ld1 {v22.8H}, [x6], #16 // ...................................................................*....... + zip2 v9.8H, v19.8H, v27.8H // ........................................................................*.. + smlal2 v11.4S, v14.8H, v0.8H // ..........................................................................* + ld1 {v4.8H}, [x3], #16 // .........................................................................*. 
+ + // ------------------------------ new position ------------------------------> + // 0 25 50 + // |------------------------|------------------------|------------------------ + // ldr q18, [x4], #32 // *.......................................................................... + // ldr q30, [x5], #32 // ..*........................................................................ + // ldr q8, [x2], #32 // .....*..................................................................... + // ldr q9, [x2, #-16] // .......*................................................................... + // uzp1 v17.8H, v8.8H, v9.8H // ..........*................................................................ + // uzp2 v4.8H, v8.8H, v9.8H // ...........*............................................................... + // ldr q19, [x4, #-16] // .*......................................................................... + // ldr q29, [x1], #32 // ............*.............................................................. + // ldr q12, [x1, #-16] // .............*............................................................. + // uzp1 v13.8H, v18.8H, v19.8H // ...*....................................................................... + // uzp1 v3.8H, v29.8H, v12.8H // ...............*........................................................... + // uzp2 v10.8H, v29.8H, v12.8H // ................*.......................................................... + // smull v12.4S, v3.4H, v4.4H // .................*......................................................... + // smull2 v11.4S, v3.8H, v4.8H // ..................*........................................................ + // ldr q5, [x5, #-16] // ......*.................................................................... + // smlal v12.4S, v10.4H, v17.4H // .....................*..................................................... 
+ // smlal2 v11.4S, v10.8H, v17.8H // ......................*.................................................... + // uzp2 v14.8H, v30.8H, v5.8H // ........*.................................................................. + // uzp1 v15.8H, v30.8H, v5.8H // .........*................................................................. + // smlal v12.4S, v13.4H, v14.4H // .........................*................................................. + // smlal2 v11.4S, v13.8H, v14.8H // ..........................*................................................ + // uzp2 v28.8H, v18.8H, v19.8H // ....*...................................................................... + // smlal v12.4S, v28.4H, v15.4H // ..............................*............................................ + // smlal2 v11.4S, v28.8H, v15.8H // ...............................*........................................... + // ld1 {v22.8H}, [x6], #16 // .............................*............................................. + // uzp1 v1.8H, v12.8H, v11.8H // ...................................*....................................... + // smull v23.4S, v3.4H, v17.4H // ...................*....................................................... + // mul v14.8H, v1.8H, v2.8H // .....................................*..................................... + // ld1 {v4.8H}, [x3], #16 // ..............*............................................................ + // smlal2 v11.4S, v14.8H, v0.8H // ........................................*.................................. + // smull2 v20.4S, v3.8H, v17.8H // ....................*...................................................... + // ldr q18, [x4], #32 // ..................................*........................................ + // ldr q30, [x5], #32 // .......................................*................................... + // smlal2 v20.4S, v10.8H, v4.8H // ........................*.................................................. 
+ // smlal v12.4S, v14.4H, v0.4H // .........................................*................................. + // smlal v23.4S, v10.4H, v4.4H // .......................*................................................... + // smlal2 v20.4S, v13.8H, v15.8H // ............................*.............................................. + // ldr q8, [x2], #32 // ..........................................*................................ + // smlal v23.4S, v13.4H, v15.4H // ...........................*............................................... + // smlal2 v20.4S, v28.8H, v22.8H // ................................*.......................................... + // ldr q9, [x2, #-16] // ...........................................*............................... + // smlal v23.4S, v28.4H, v22.4H // .................................*......................................... + // uzp2 v27.8H, v12.8H, v11.8H // ..................................................................*........ + // uzp1 v17.8H, v8.8H, v9.8H // .............................................*............................. + // uzp2 v4.8H, v8.8H, v9.8H // ................................................*.......................... + // uzp1 v5.8H, v23.8H, v20.8H // ....................................*...................................... + // mul v31.8H, v5.8H, v2.8H // ......................................*.................................... + // ldr q19, [x4, #-16] // ............................................*.............................. + // ldr q29, [x1], #32 // ..............................................*............................ + // ldr q12, [x1, #-16] // ...............................................*........................... + // smlal2 v20.4S, v31.8H, v0.8H // ..............................................................*............ + // uzp1 v13.8H, v18.8H, v19.8H // .................................................*......................... 
+ // uzp1 v3.8H, v29.8H, v12.8H // ..................................................*........................ + // uzp2 v10.8H, v29.8H, v12.8H // ...................................................*....................... + // smull v12.4S, v3.4H, v4.4H // ....................................................*...................... + // smull2 v11.4S, v3.8H, v4.8H // .....................................................*..................... + // ldr q5, [x5, #-16] // ......................................................*.................... + // smlal v12.4S, v10.4H, v17.4H // .......................................................*................... + // smlal2 v11.4S, v10.8H, v17.8H // ........................................................*.................. + // uzp2 v14.8H, v30.8H, v5.8H // .........................................................*................. + // uzp1 v15.8H, v30.8H, v5.8H // ..........................................................*................ + // smlal v12.4S, v13.4H, v14.4H // ...........................................................*............... + // smlal2 v11.4S, v13.8H, v14.8H // ............................................................*.............. + // uzp2 v28.8H, v18.8H, v19.8H // .............................................................*............. + // smlal v23.4S, v31.4H, v0.4H // .................................................................*......... + // smlal v12.4S, v28.4H, v15.4H // ...............................................................*........... + // smlal2 v11.4S, v28.8H, v15.8H // ................................................................*.......... + // ld1 {v22.8H}, [x6], #16 // .......................................................................*... + // uzp2 v19.8H, v23.8H, v20.8H // .....................................................................*..... 
+ // uzp1 v1.8H, v12.8H, v11.8H // ....................................................................*...... + // smull v23.4S, v3.4H, v17.4H // ...................................................................*....... + // mul v14.8H, v1.8H, v2.8H // ......................................................................*.... + // zip2 v9.8H, v19.8H, v27.8H // ........................................................................*.. + // ld1 {v4.8H}, [x3], #16 // ..........................................................................* + // smlal2 v11.4S, v14.8H, v0.8H // .........................................................................*. + + sub count, count, #2 +1: + // Instructions: 48 + // Expected cycles: 58 + // Expected IPC: 0.83 + + // Cycle bound: 58.0 + // IPC bound: 0.83 + + // Wall time: 6.39s + // User time: 6.39s + + // -------------- original position --------------> + // 0 25 + // |------------------------|---------------------- + smull2 v20.4S, v3.8H, v17.8H // ..........*..................................... + ldr q18, [x4], #32 // .................e.............................. + ldr q30, [x5], #32 // .....................e.......................... + smlal2 v20.4S, v10.8H, v4.8H // ............*................................... + smlal v12.4S, v14.4H, v0.4H // .........................................*...... + smlal v23.4S, v10.4H, v4.4H // ...........*.................................... + str q9, [x0, #16] // ...............................................l + smlal2 v20.4S, v13.8H, v15.8H // ...........................*.................... + ldr q8, [x2], #32 // ....e........................................... + smlal v23.4S, v13.4H, v15.4H // ..........................*..................... + smlal2 v20.4S, v28.8H, v22.8H // .............................*.................. + zip1 v26.8H, v19.8H, v27.8H // ............................................l... 
+ ldr q9, [x2, #-16] // .....e.......................................... + smlal v23.4S, v28.4H, v22.4H // ............................*................... + uzp2 v27.8H, v12.8H, v11.8H // ...........................................*.... + uzp1 v17.8H, v8.8H, v9.8H // ......e......................................... + uzp2 v4.8H, v8.8H, v9.8H // .......e........................................ + uzp1 v5.8H, v23.8H, v20.8H // ..................................*............. + str q26, [x0], #32 // ..............................................l. + mul v31.8H, v5.8H, v2.8H // ...................................*............ + ldr q19, [x4, #-16] // ..................e............................. + ldr q29, [x1], #32 // e............................................... + ldr q12, [x1, #-16] // .e.............................................. + smlal2 v20.4S, v31.8H, v0.8H // .....................................*.......... + uzp1 v13.8H, v18.8H, v19.8H // ...................e............................ + uzp1 v3.8H, v29.8H, v12.8H // ..e............................................. + uzp2 v10.8H, v29.8H, v12.8H // ...e............................................ + smull v12.4S, v3.4H, v4.4H // .............e.................................. + smull2 v11.4S, v3.8H, v4.8H // ..............e................................. + ldr q5, [x5, #-16] // ......................e......................... + smlal v12.4S, v10.4H, v17.4H // ...............e................................ + smlal2 v11.4S, v10.8H, v17.8H // ................e............................... + uzp2 v14.8H, v30.8H, v5.8H // ........................e....................... + uzp1 v15.8H, v30.8H, v5.8H // .......................e........................ + smlal v12.4S, v13.4H, v14.4H // ..............................e................. + smlal2 v11.4S, v13.8H, v14.8H // ...............................e................ + uzp2 v28.8H, v18.8H, v19.8H // ....................e........................... 
+ smlal v23.4S, v31.4H, v0.4H // ....................................*........... + smlal v12.4S, v28.4H, v15.4H // ................................e............... + smlal2 v11.4S, v28.8H, v15.8H // .................................e.............. + ld1 {v22.8H}, [x6], #16 // .........................e...................... + uzp2 v19.8H, v23.8H, v20.8H // ......................................*......... + uzp1 v1.8H, v12.8H, v11.8H // .......................................e........ + smull v23.4S, v3.4H, v17.4H // .........e...................................... + mul v14.8H, v1.8H, v2.8H // ........................................e....... + zip2 v9.8H, v19.8H, v27.8H // .............................................*.. + ld1 {v4.8H}, [x3], #16 // ........e....................................... + smlal2 v11.4S, v14.8H, v0.8H // ..........................................e..... + + // ------------------------------------------------- new position --------------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + // ldr q12, [x1], #32 // ....................e..........................'....................~..........................'.................. + // ldr q13, [x1, #-16] // .....................e.........................'.....................~.........................'.................. + // uzp1 v3.8h, v12.8h, v13.8h // ........................e......................'........................~......................'.................. + // uzp2 v4.8h, v12.8h, v13.8h // .........................e.....................'.........................~.....................'.................. + // ldr q12, [x2], #32 // .......e.......................................'.......~.......................................'.......~.......... 
+ // ldr q13, [x2, #-16] // ...........e...................................'...........~...................................'...........~...... + // uzp1 v5.8h, v12.8h, v13.8h // ..............e................................'..............~................................'..............~... + // uzp2 v6.8h, v12.8h, v13.8h // ...............e...............................'...............~...............................'...............~.. + // ld1 {v7.8h}, [x3], #16 // .............................................e.'.............................................~.'.................. + // smull v8.4s, v3.4h, v5.4h // ..........................................e....'..........................................~....'.................. + // smull2 v10.4s, v3.8h, v5.8h // ...............................................*...............................................~.................. + // smlal v8.4s, v4.4h, v7.4h // ....~..........................................'....*..........................................'....~............. + // smlal2 v10.4s, v4.8h, v7.8h // ..~............................................'..*............................................'..~............... + // smull v9.4s, v3.4h, v6.4h // ..........................e....................'..........................~....................'.................. + // smull2 v11.4s, v3.8h, v6.8h // ...........................e...................'...........................~...................'.................. + // smlal v9.4s, v4.4h, v5.4h // .............................e.................'.............................~.................'.................. + // smlal2 v11.4s, v4.8h, v5.8h // ..............................e................'..............................~................'.................. + // ldr q12, [x4], #32 // e..............................................'~..............................................'~................. 
+ // ldr q13, [x4, #-16] // ...................e...........................'...................~...........................'.................. + // uzp1 v3.8h, v12.8h, v13.8h // .......................e.......................'.......................~.......................'.................. + // uzp2 v4.8h, v12.8h, v13.8h // ...................................e...........'...................................~...........'.................. + // ldr q12, [x5], #32 // .e.............................................'.~.............................................'.~................ + // ldr q13, [x5, #-16] // ............................e..................'............................~..................'.................. + // uzp1 v5.8h, v12.8h, v13.8h // ................................e..............'................................~..............'.................. + // uzp2 v6.8h, v12.8h, v13.8h // ...............................e...............'...............................~...............'.................. + // ld1 {v7.8h}, [x6], #16 // .......................................e.......'.......................................~.......'.................. + // smlal v8.4s, v3.4h, v5.4h // ........~......................................'........*......................................'........~......... + // smlal2 v10.4s, v3.8h, v5.8h // ......~........................................'......*........................................'......~........... + // smlal v8.4s, v4.4h, v7.4h // ............~..................................'............*..................................'............~..... + // smlal2 v10.4s, v4.8h, v7.8h // .........~.....................................'.........*.....................................'.........~........ + // smlal v9.4s, v3.4h, v6.4h // .................................e.............'.................................~.............'.................. 
+ // smlal2 v11.4s, v3.8h, v6.8h // ..................................e............'..................................~............'.................. + // smlal v9.4s, v4.4h, v5.4h // .....................................e.........'.....................................~.........'.................. + // smlal2 v11.4s, v4.8h, v5.8h // ......................................e........'......................................~........'.................. + // uzp1 v28.8h, v8.8h, v10.8h // ................~..............................'................*..............................'................~. + // mul v28.8h, v28.8h, v2.8h // ..................~............................'..................*............................'.................. + // smlal v8.4s, v28.4h, v0.4h // ....................................~..........'....................................*..........'.................. + // smlal2 v10.4s, v28.8h, v0.8h // ......................~........................'......................*........................'.................. + // uzp2 v26.8h, v8.8h, v10.8h // ........................................~......'........................................*......'.................. + // uzp1 v28.8h, v9.8h, v11.8h // .........................................e.....'.........................................~.....'.................. + // mul v28.8h, v28.8h, v2.8h // ...........................................e...'...........................................~...'.................. + // smlal v9.4s, v28.4h, v0.4h // ...~...........................................'...*...........................................'...~.............. + // smlal2 v11.4s, v28.8h, v0.8h // ..............................................e'..............................................~'.................. + // uzp2 v27.8h, v9.8h, v11.8h // .............~.................................'.............*.................................'.............~.... 
+ // zip1 v12.8h, v26.8h, v27.8h // ..........~....................................'..........~....................................'..........l....... + // zip2 v13.8h, v26.8h, v27.8h // ............................................~..'............................................*..'.................. + // str q12, [x0], #32 // .................~.............................'.................~.............................'.................l + // str q13, [x0, #-16] // .....~.........................................'.....~.........................................'.....l............ + + sub count, count, #1 + cbnz count, 1b + // Instructions: 21 + // Expected cycles: 35 + // Expected IPC: 0.60 + + // Cycle bound: 35.0 + // IPC bound: 0.60 + + // Wall time: 0.08s + // User time: 0.08s + + // ----- original position -----> + // 0 25 + // |------------------------|---- + smull2 v5.4S, v3.8H, v17.8H // *............................. + smlal v12.4S, v14.4H, v0.4H // ..*........................... + smlal v23.4S, v10.4H, v4.4H // ...*.......................... + str q9, [x0, #16] // ....*......................... + smlal2 v5.4S, v10.8H, v4.8H // .*............................ + uzp2 v11.8H, v12.8H, v11.8H // ..........*................... + zip1 v9.8H, v19.8H, v27.8H // ........*..................... + smlal v23.4S, v13.4H, v15.4H // ......*....................... + smlal2 v5.4S, v13.8H, v15.8H // .....*........................ + str q9, [x0], #32 // ............*................. + smlal v23.4S, v28.4H, v22.4H // .........*.................... + smlal2 v5.4S, v28.8H, v22.8H // .......*...................... + uzp1 v9.8H, v23.8H, v5.8H // ...........*.................. + mul v9.8H, v9.8H, v2.8H // .............*................ + smlal2 v5.4S, v9.8H, v0.8H // ..............*............... + smlal v23.4S, v9.4H, v0.4H // ...............*.............. + uzp2 v9.8H, v23.8H, v5.8H // ................*............. + zip2 v5.8H, v9.8H, v11.8H // .................*............ 
+ zip1 v9.8H, v9.8H, v11.8H // ...................*.......... + str q5, [x0, #16] // ..................*........... + str q9, [x0], #32 // ....................*......... + + // -------- new position --------> + // 0 25 + // |------------------------|----- + // smull2 v20.4S, v3.8H, v17.8H // *.............................. + // smlal2 v20.4S, v10.8H, v4.8H // ....*.......................... + // smlal v12.4S, v14.4H, v0.4H // .*............................. + // smlal v23.4S, v10.4H, v4.4H // ..*............................ + // str q9, [x0, #16] // ...*........................... + // smlal2 v20.4S, v13.8H, v15.8H // ........*...................... + // smlal v23.4S, v13.4H, v15.4H // .......*....................... + // smlal2 v20.4S, v28.8H, v22.8H // ...........*................... + // zip1 v26.8H, v19.8H, v27.8H // ......*........................ + // smlal v23.4S, v28.4H, v22.4H // ..........*.................... + // uzp2 v27.8H, v12.8H, v11.8H // .....*......................... + // uzp1 v5.8H, v23.8H, v20.8H // ............*.................. + // str q26, [x0], #32 // .........*..................... + // mul v31.8H, v5.8H, v2.8H // .............*................. + // smlal2 v20.4S, v31.8H, v0.8H // ..............*................ + // smlal v23.4S, v31.4H, v0.4H // ...............*............... + // uzp2 v19.8H, v23.8H, v20.8H // ................*.............. + // zip2 v9.8H, v19.8H, v27.8H // .................*............. + // str q9, [x0, #16] // ...................*........... + // zip1 v26.8H, v19.8H, v27.8H // ..................*............ + // str q26, [x0], #32 // ....................*.......... 
+ + + pop_stack + ret +#endif /* MLKEM_K == 2 */ + +#if MLKEM_K == 3 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + + mov count, #(MLKEM_N / 16) + // Instructions: 75 + // Expected cycles: 103 + // Expected IPC: 0.73 + + // Cycle bound: 103.0 + // IPC bound: 0.73 + + // Wall time: 0.94s + // User time: 0.94s + + // --------------------------- original position ----------------------------> + // 0 25 50 + // |------------------------|------------------------| + ldr q7, [x2, #16] // *.......................................................................... + ldr q20, [x2], #32 // ..*........................................................................ + ldr q15, [x1, #16] // .*......................................................................... + uzp1 v8.8H, v20.8H, v7.8H // ...............*........................................................... + uzp2 v7.8H, v20.8H, v7.8H // ................*.......................................................... + ld1 {v20.8H}, [x3], #16 // ...*....................................................................... + ldr q30, [x1], #32 // ..............*............................................................ + ldr q11, [x4], #32 // ....*...................................................................... + uzp1 v16.8H, v30.8H, v15.8H // .................*......................................................... + uzp2 v15.8H, v30.8H, v15.8H // ..................*........................................................ 
+ smull v30.4S, v16.4H, v7.4H // ...................*....................................................... + smull2 v7.4S, v16.8H, v7.8H // ....................*...................................................... + smull v9.4S, v16.4H, v8.4H // .....................*..................................................... + smull2 v16.4S, v16.8H, v8.8H // ......................*.................................................... + smlal v30.4S, v15.4H, v8.4H // .......................*................................................... + smlal2 v7.4S, v15.8H, v8.8H // ........................*.................................................. + smlal v9.4S, v15.4H, v20.4H // .........................*................................................. + smlal2 v16.4S, v15.8H, v20.8H // ..........................*................................................ + ldr q20, [x4, #-16] // .....*..................................................................... + ldr q15, [x5], #32 // ......*.................................................................... + uzp1 v8.8H, v11.8H, v20.8H // ...........................*............................................... + uzp2 v20.8H, v11.8H, v20.8H // ............................*.............................................. + ldr q11, [x5, #-16] // .......*................................................................... + ld1 {v27.8H}, [x6], #16 // ........*.................................................................. + uzp1 v10.8H, v15.8H, v11.8H // .............................*............................................. + uzp2 v15.8H, v15.8H, v11.8H // ..............................*............................................ + smlal v9.4S, v8.4H, v10.4H // ...............................*........................................... + smlal2 v16.4S, v8.8H, v10.8H // ................................*.......................................... 
+ smlal v30.4S, v8.4H, v15.4H // .................................*......................................... + smlal2 v7.4S, v8.8H, v15.8H // ..................................*........................................ + smlal v9.4S, v20.4H, v27.4H // ...................................*....................................... + smlal2 v16.4S, v20.8H, v27.8H // ....................................*...................................... + smlal v30.4S, v20.4H, v10.4H // .....................................*..................................... + smlal2 v7.4S, v20.8H, v10.8H // ......................................*.................................... + ldr q20, [x7], #32 // .........*................................................................. + ldr q15, [x7, #-16] // ..........*................................................................ + ldr q8, [x8], #32 // ...........*............................................................... + uzp1 v11.8H, v20.8H, v15.8H // .......................................*................................... + uzp2 v20.8H, v20.8H, v15.8H // ........................................*.................................. + ldr q15, [x8, #-16] // ............*.............................................................. + ld1 {v27.8H}, [x9], #16 // .............*............................................................. + uzp1 v10.8H, v8.8H, v15.8H // .........................................*................................. + uzp2 v15.8H, v8.8H, v15.8H // ..........................................*................................ + smlal v9.4S, v11.4H, v10.4H // ...........................................*............................... + smlal2 v16.4S, v11.8H, v10.8H // ............................................*.............................. + smlal v30.4S, v11.4H, v15.4H // .............................................*............................. 
+ smlal2 v7.4S, v11.8H, v15.8H // ..............................................*............................ + smlal v9.4S, v20.4H, v27.4H // ...............................................*........................... + smlal2 v16.4S, v20.8H, v27.8H // ................................................*.......................... + smlal v30.4S, v20.4H, v10.4H // .................................................*......................... + smlal2 v7.4S, v20.8H, v10.8H // ..................................................*........................ + ldr q15, [x2], #32 // ...............................................................*........... + uzp1 v20.8H, v9.8H, v16.8H // ....................................................*...................... + uzp1 v8.8H, v30.8H, v7.8H // .....................................................*..................... + mul v20.8H, v20.8H, v2.8H // ......................................................*.................... + mul v8.8H, v8.8H, v2.8H // .......................................................*................... + ldr q21, [x4], #32 // .................................................................*......... + smlal v9.4S, v20.4H, v0.4H // ........................................................*.................. + smlal2 v16.4S, v20.8H, v0.8H // .........................................................*................. + smlal v30.4S, v8.4H, v0.4H // ..........................................................*................ + smlal2 v7.4S, v8.8H, v0.8H // ...........................................................*............... + ldr q6, [x4, #-16] // ..................................................................*........ + uzp2 v27.8H, v9.8H, v16.8H // ............................................................*.............. + uzp2 v10.8H, v30.8H, v7.8H // .............................................................*............. 
+ ldr q16, [x2, #-16] // ...................................................*....................... + ldr q30, [x1, #16] // ..............................................................*............ + ld1 {v9.8H}, [x3], #16 // ................................................................*.......... + ldr q1, [x5], #32 // ...................................................................*....... + ldr q12, [x5, #-16] // ....................................................................*...... + ld1 {v24.8H}, [x6], #16 // .....................................................................*..... + ldr q19, [x7], #32 // ......................................................................*.... + ldr q31, [x7, #-16] // .......................................................................*... + ldr q17, [x8], #32 // ........................................................................*.. + ldr q18, [x8, #-16] // .........................................................................*. + ld1 {v25.8H}, [x9], #16 // ..........................................................................* + + // ------------------------------ new position ------------------------------> + // 0 25 50 + // |------------------------|------------------------|------------------------ + // ldr q16, [x2, #16] // *.......................................................................... + // ldr q30, [x1, #16] // ..*........................................................................ + // ldr q15, [x2], #32 // .*......................................................................... + // ld1 {v9.8H}, [x3], #16 // .....*..................................................................... + // ldr q21, [x4], #32 // .......*................................................................... + // ldr q6, [x4, #-16] // ..................*........................................................ 
+ // ldr q1, [x5], #32 // ...................*....................................................... + // ldr q12, [x5, #-16] // ......................*.................................................... + // ld1 {v24.8H}, [x6], #16 // .......................*................................................... + // ldr q19, [x7], #32 // ..................................*........................................ + // ldr q31, [x7, #-16] // ...................................*....................................... + // ldr q17, [x8], #32 // ....................................*...................................... + // ldr q18, [x8, #-16] // .......................................*................................... + // ld1 {v25.8H}, [x9], #16 // ........................................*.................................. + // ldr q20, [x1], #32 // ......*.................................................................... + // uzp1 v7.8H, v15.8H, v16.8H // ...*....................................................................... + // uzp2 v15.8H, v15.8H, v16.8H // ....*...................................................................... + // uzp1 v8.8H, v20.8H, v30.8H // ........*.................................................................. + // uzp2 v20.8H, v20.8H, v30.8H // .........*................................................................. + // smull v30.4S, v8.4H, v15.4H // ..........*................................................................ + // smull2 v15.4S, v8.8H, v15.8H // ...........*............................................................... + // smull v11.4S, v8.4H, v7.4H // ............*.............................................................. + // smull2 v8.4S, v8.8H, v7.8H // .............*............................................................. + // smlal v30.4S, v20.4H, v7.4H // ..............*............................................................ 
+ // smlal2 v15.4S, v20.8H, v7.8H // ...............*........................................................... + // smlal v11.4S, v20.4H, v9.4H // ................*.......................................................... + // smlal2 v8.4S, v20.8H, v9.8H // .................*......................................................... + // uzp1 v7.8H, v21.8H, v6.8H // ....................*...................................................... + // uzp2 v20.8H, v21.8H, v6.8H // .....................*..................................................... + // uzp1 v16.8H, v1.8H, v12.8H // ........................*.................................................. + // uzp2 v9.8H, v1.8H, v12.8H // .........................*................................................. + // smlal v11.4S, v7.4H, v16.4H // ..........................*................................................ + // smlal2 v8.4S, v7.8H, v16.8H // ...........................*............................................... + // smlal v30.4S, v7.4H, v9.4H // ............................*.............................................. + // smlal2 v15.4S, v7.8H, v9.8H // .............................*............................................. + // smlal v11.4S, v20.4H, v24.4H // ..............................*............................................ + // smlal2 v8.4S, v20.8H, v24.8H // ...............................*........................................... + // smlal v30.4S, v20.4H, v16.4H // ................................*.......................................... + // smlal2 v15.4S, v20.8H, v16.8H // .................................*......................................... + // uzp1 v7.8H, v19.8H, v31.8H // .....................................*..................................... + // uzp2 v20.8H, v19.8H, v31.8H // ......................................*.................................... 
+ // uzp1 v16.8H, v17.8H, v18.8H // .........................................*................................. + // uzp2 v9.8H, v17.8H, v18.8H // ..........................................*................................ + // smlal v11.4S, v7.4H, v16.4H // ...........................................*............................... + // smlal2 v8.4S, v7.8H, v16.8H // ............................................*.............................. + // smlal v30.4S, v7.4H, v9.4H // .............................................*............................. + // smlal2 v15.4S, v7.8H, v9.8H // ..............................................*............................ + // smlal v11.4S, v20.4H, v25.4H // ...............................................*........................... + // smlal2 v8.4S, v20.8H, v25.8H // ................................................*.......................... + // smlal v30.4S, v20.4H, v16.4H // .................................................*......................... + // smlal2 v15.4S, v20.8H, v16.8H // ..................................................*........................ + // ldr q16, [x2, #16] // ................................................................*.......... + // uzp1 v7.8H, v11.8H, v8.8H // ....................................................*...................... + // uzp1 v20.8H, v30.8H, v15.8H // .....................................................*..................... + // mul v7.8H, v7.8H, v2.8H // ......................................................*.................... + // mul v20.8H, v20.8H, v2.8H // .......................................................*................... + // smlal v11.4S, v7.4H, v0.4H // .........................................................*................. + // smlal2 v8.4S, v7.8H, v0.8H // ..........................................................*................ + // smlal v30.4S, v20.4H, v0.4H // ...........................................................*............... 
+ // smlal2 v15.4S, v20.8H, v0.8H // ............................................................*.............. + // uzp2 v27.8H, v11.8H, v8.8H // ..............................................................*............ + // uzp2 v10.8H, v30.8H, v15.8H // ...............................................................*........... + // ldr q30, [x1, #16] // .................................................................*......... + // ldr q15, [x2], #32 // ...................................................*....................... + // ld1 {v9.8H}, [x3], #16 // ..................................................................*........ + // ldr q21, [x4], #32 // ........................................................*.................. + // ldr q6, [x4, #-16] // .............................................................*............. + // ldr q1, [x5], #32 // ...................................................................*....... + // ldr q12, [x5, #-16] // ....................................................................*...... + // ld1 {v24.8H}, [x6], #16 // .....................................................................*..... + // ldr q19, [x7], #32 // ......................................................................*.... + // ldr q31, [x7, #-16] // .......................................................................*... + // ldr q17, [x8], #32 // ........................................................................*.. + // ldr q18, [x8, #-16] // .........................................................................*. 
+ // ld1 {v25.8H}, [x9], #16 // ..........................................................................* + + sub count, count, #2 +1: + // Instructions: 65 + // Expected cycles: 80 + // Expected IPC: 0.81 + + // Cycle bound: 80.0 + // IPC bound: 0.81 + + // Wall time: 11.64s + // User time: 11.64s + + // ---------------------- original position -----------------------> + // 0 25 50 + // |------------------------|------------------------|-------------- + ldr q20, [x1], #32 // *................................................................ + uzp1 v7.8H, v15.8H, v16.8H // ......*.......................................................... + uzp2 v15.8H, v15.8H, v16.8H // .......*......................................................... + uzp1 v8.8H, v20.8H, v30.8H // ..*.............................................................. + uzp2 v20.8H, v20.8H, v30.8H // ...*............................................................. + smull v30.4S, v8.4H, v15.4H // .............*................................................... + smull2 v15.4S, v8.8H, v15.8H // ..............*.................................................. + smull v11.4S, v8.4H, v7.4H // .........*....................................................... + smull2 v8.4S, v8.8H, v7.8H // ..........*...................................................... + smlal v30.4S, v20.4H, v7.4H // ...............*................................................. + smlal2 v15.4S, v20.8H, v7.8H // ................*................................................ + smlal v11.4S, v20.4H, v9.4H // ...........*..................................................... + smlal2 v8.4S, v20.8H, v9.8H // ............*.................................................... + uzp1 v7.8H, v21.8H, v6.8H // ...................*............................................. + uzp2 v20.8H, v21.8H, v6.8H // ....................*............................................ 
+ uzp1 v16.8H, v1.8H, v12.8H // .......................*......................................... + uzp2 v9.8H, v1.8H, v12.8H // ........................*........................................ + smlal v11.4S, v7.4H, v16.4H // ..........................*...................................... + smlal2 v8.4S, v7.8H, v16.8H // ...........................*..................................... + smlal v30.4S, v7.4H, v9.4H // ..............................*.................................. + smlal2 v15.4S, v7.8H, v9.8H // ...............................*................................. + smlal v11.4S, v20.4H, v24.4H // ............................*.................................... + smlal2 v8.4S, v20.8H, v24.8H // .............................*................................... + smlal v30.4S, v20.4H, v16.4H // ................................*................................ + smlal2 v15.4S, v20.8H, v16.8H // .................................*............................... + uzp1 v7.8H, v19.8H, v31.8H // ....................................*............................ + uzp2 v20.8H, v19.8H, v31.8H // .....................................*........................... + uzp1 v16.8H, v17.8H, v18.8H // ........................................*........................ + uzp2 v9.8H, v17.8H, v18.8H // .........................................*....................... + smlal v11.4S, v7.4H, v16.4H // ...........................................*..................... + smlal2 v8.4S, v7.8H, v16.8H // ............................................*.................... + smlal v30.4S, v7.4H, v9.4H // ...............................................*................. + smlal2 v15.4S, v7.8H, v9.8H // ................................................*................ + smlal v11.4S, v20.4H, v25.4H // .............................................*................... + smlal2 v8.4S, v20.8H, v25.8H // ..............................................*.................. 
+ smlal v30.4S, v20.4H, v16.4H // .................................................*............... + smlal2 v15.4S, v20.8H, v16.8H // ..................................................*.............. + ldr q16, [x2, #16] // .....e........................................................... + uzp1 v7.8H, v11.8H, v8.8H // ...................................................*............. + uzp1 v20.8H, v30.8H, v15.8H // ........................................................*........ + mul v7.8H, v7.8H, v2.8H // ....................................................*............ + mul v20.8H, v20.8H, v2.8H // .........................................................*....... + zip2 v9.8H, v27.8H, v10.8H // ..............................................................l.. + zip1 v27.8H, v27.8H, v10.8H // .............................................................l... + smlal v11.4S, v7.4H, v0.4H // .....................................................*........... + smlal2 v8.4S, v7.8H, v0.8H // ......................................................*.......... + smlal v30.4S, v20.4H, v0.4H // ..........................................................*...... + smlal2 v15.4S, v20.8H, v0.8H // ...........................................................*..... + str q27, [x0], #32 // ...............................................................l. + uzp2 v27.8H, v11.8H, v8.8H // .......................................................*......... + str q9, [x0, #-16] // ................................................................l + uzp2 v10.8H, v30.8H, v15.8H // ............................................................*.... + ldr q30, [x1, #16] // .e............................................................... + ldr q15, [x2], #32 // ....e............................................................ + ld1 {v9.8H}, [x3], #16 // ........e........................................................ 
+ ldr q21, [x4], #32 // .................e............................................... + ldr q6, [x4, #-16] // ..................e.............................................. + ldr q1, [x5], #32 // .....................e........................................... + ldr q12, [x5, #-16] // ......................e.......................................... + ld1 {v24.8H}, [x6], #16 // .........................e....................................... + ldr q19, [x7], #32 // ..................................e.............................. + ldr q31, [x7, #-16] // ...................................e............................. + ldr q17, [x8], #32 // ......................................e.......................... + ldr q18, [x8, #-16] // .......................................e......................... + ld1 {v25.8H}, [x9], #16 // ..........................................e...................... + + // ---------------------------------------------------------------- new position -----------------------------------------------------------------> + // 0 25 50 75 100 125 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------ + // ldr q12, [x1], #32 // ............................*................................................................~.................................................. + // ldr q13, [x1, #-16] // ...............e............'...................................................~............'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'..*.............................................................'..~............................................... + // uzp2 v4.8h, v12.8h, v13.8h // ............................'...*............................................................'...~.............................................. 
+ // ldr q12, [x2], #32 // ................e...........'....................................................~...........'.................................................. + // ldr q13, [x2, #-16] // e...........................'....................................~...........................'....................................~............. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'*...............................................................'~................................................. + // uzp2 v6.8h, v12.8h, v13.8h // ............................'.*..............................................................'.~................................................ + // ld1 {v7.8h}, [x3], #16 // .................e..........'.....................................................~..........'.................................................. + // smull v8.4s, v3.4h, v5.4h // ............................'......*.........................................................'......~........................................... + // smull2 v10.4s, v3.8h, v5.8h // ............................'.......*........................................................'.......~.......................................... + // smlal v8.4s, v4.4h, v7.4h // ............................'..........*.....................................................'..........~....................................... + // smlal2 v10.4s, v4.8h, v7.8h // ............................'...........*....................................................'...........~...................................... + // smull v9.4s, v3.4h, v6.4h // ............................'....*...........................................................'....~............................................. + // smull2 v11.4s, v3.8h, v6.8h // ............................'.....*..........................................................'.....~............................................ 
+ // smlal v9.4s, v4.4h, v5.4h // ............................'........*.......................................................'........~......................................... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'.........*......................................................'.........~........................................ + // ldr q12, [x4], #32 // ..................e.........'......................................................~.........'.................................................. + // ldr q13, [x4, #-16] // ...................e........'.......................................................~........'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'............*...................................................'............~..................................... + // uzp2 v4.8h, v12.8h, v13.8h // ............................'.............*..................................................'.............~.................................... + // ldr q12, [x5], #32 // ....................e.......'........................................................~.......'.................................................. + // ldr q13, [x5, #-16] // .....................e......'.........................................................~......'.................................................. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'..............*.................................................'..............~................................... + // uzp2 v6.8h, v12.8h, v13.8h // ............................'...............*................................................'...............~.................................. + // ld1 {v7.8h}, [x6], #16 // ......................e.....'..........................................................~.....'.................................................. 
+ // smlal v8.4s, v3.4h, v5.4h // ............................'................*...............................................'................~................................. + // smlal2 v10.4s, v3.8h, v5.8h // ............................'.................*..............................................'.................~................................ + // smlal v8.4s, v4.4h, v7.4h // ............................'....................*...........................................'....................~............................. + // smlal2 v10.4s, v4.8h, v7.8h // ............................'.....................*..........................................'.....................~............................ + // smlal v9.4s, v3.4h, v6.4h // ............................'..................*.............................................'..................~............................... + // smlal2 v11.4s, v3.8h, v6.8h // ............................'...................*............................................'...................~.............................. + // smlal v9.4s, v4.4h, v5.4h // ............................'......................*.........................................'......................~........................... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'.......................*........................................'.......................~.......................... + // ldr q12, [x7], #32 // .......................e....'...........................................................~....'.................................................. + // ldr q13, [x7, #-16] // ........................e...'............................................................~...'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'........................*.......................................'........................~......................... 
+ // uzp2 v4.8h, v12.8h, v13.8h // ............................'.........................*......................................'.........................~........................ + // ldr q12, [x8], #32 // .........................e..'.............................................................~..'.................................................. + // ldr q13, [x8, #-16] // ..........................e.'..............................................................~.'.................................................. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'..........................*.....................................'..........................~....................... + // uzp2 v6.8h, v12.8h, v13.8h // ............................'...........................*....................................'...........................~...................... + // ld1 {v7.8h}, [x9], #16 // ...........................e'...............................................................~'.................................................. + // smlal v8.4s, v3.4h, v5.4h // ............................'............................*...................................'............................~..................... + // smlal2 v10.4s, v3.8h, v5.8h // ............................'.............................*..................................'.............................~.................... + // smlal v8.4s, v4.4h, v7.4h // ............................'................................*...............................'................................~................. + // smlal2 v10.4s, v4.8h, v7.8h // ............................'.................................*..............................'.................................~................ + // smlal v9.4s, v3.4h, v6.4h // ............................'..............................*.................................'..............................~................... 
+ // smlal2 v11.4s, v3.8h, v6.8h // ............................'...............................*................................'...............................~.................. + // smlal v9.4s, v4.4h, v5.4h // ............................'..................................*.............................'..................................~............... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'...................................*............................'...................................~.............. + // uzp1 v28.8h, v8.8h, v10.8h // .~..........................'.....................................*..........................'.....................................~............ + // mul v28.8h, v28.8h, v2.8h // ...~........................'.......................................*........................'.......................................~.......... + // smlal v8.4s, v28.4h, v0.4h // .......~....................'...........................................*....................'...........................................~...... + // smlal2 v10.4s, v28.8h, v0.8h // ........~...................'............................................*...................'............................................~..... + // uzp2 v26.8h, v8.8h, v10.8h // ............~...............'................................................*...............'................................................~. + // uzp1 v28.8h, v9.8h, v11.8h // ..~.........................'......................................*.........................'......................................~........... + // mul v28.8h, v28.8h, v2.8h // ....~.......................'........................................*.......................'........................................~......... + // smlal v9.4s, v28.4h, v0.4h // .........~..................'.............................................*..................'.............................................~.... 
+ // smlal2 v11.4s, v28.8h, v0.8h // ..........~.................'..............................................*.................'..............................................~... + // uzp2 v27.8h, v9.8h, v11.8h // ..............~.............'..................................................*.............'.................................................. + // zip1 v12.8h, v26.8h, v27.8h // ......~.....................'..........................................~.....................'..........................................l....... + // zip2 v13.8h, v26.8h, v27.8h // .....~......................'.........................................~......................'.........................................l........ + // str q12, [x0], #32 // ...........~................'...............................................~................'...............................................l.. + // str q13, [x0, #-16] // .............~..............'.................................................~..............'.................................................l + + sub count, count, #1 + cbnz count, 1b + // Instructions: 55 + // Expected cycles: 61 + // Expected IPC: 0.90 + + // Cycle bound: 61.0 + // IPC bound: 0.90 + + // Wall time: 8.41s + // User time: 8.41s + + // ----------------- original position ------------------> + // 0 25 50 + // |------------------------|------------------------|---- + ldr q7, [x1], #32 // *...................................................... + uzp1 v20.8H, v15.8H, v16.8H // .*..................................................... + uzp2 v15.8H, v15.8H, v16.8H // ..*.................................................... + uzp1 v23.8H, v7.8H, v30.8H // ...*................................................... + uzp2 v11.8H, v7.8H, v30.8H // ....*.................................................. + smull2 v8.4S, v23.8H, v20.8H // ........*.............................................. 
+ smull v5.4S, v23.4H, v20.4H // .......*............................................... + smull2 v30.4S, v23.8H, v15.8H // ......*................................................ + uzp1 v28.8H, v1.8H, v12.8H // ...............*....................................... + smlal2 v8.4S, v11.8H, v9.8H // ............*.......................................... + smlal v5.4S, v11.4H, v9.4H // ...........*........................................... + uzp1 v3.8H, v21.8H, v6.8H // .............*......................................... + smull v16.4S, v23.4H, v15.4H // .....*................................................. + smlal2 v8.4S, v3.8H, v28.8H // ..................*.................................... + smlal v5.4S, v3.4H, v28.4H // .................*..................................... + uzp2 v29.8H, v21.8H, v6.8H // ..............*........................................ + uzp1 v7.8H, v17.8H, v18.8H // ...........................*........................... + smlal2 v8.4S, v29.8H, v24.8H // ......................*................................ + uzp1 v14.8H, v19.8H, v31.8H // .........................*............................. + smlal v16.4S, v11.4H, v20.4H // .........*............................................. + smlal2 v30.4S, v11.8H, v20.8H // ..........*............................................ + smlal2 v8.4S, v14.8H, v7.8H // ..............................*........................ + uzp2 v20.8H, v1.8H, v12.8H // ................*...................................... + uzp2 v21.8H, v19.8H, v31.8H // ..........................*............................ + smlal2 v30.4S, v3.8H, v20.8H // ....................*.................................. + smlal v16.4S, v3.4H, v20.4H // ...................*................................... + smlal v5.4S, v29.4H, v24.4H // .....................*................................. + uzp2 v9.8H, v17.8H, v18.8H // ............................*.......................... 
+ smlal2 v30.4S, v29.8H, v28.8H // ........................*.............................. + smlal v16.4S, v29.4H, v28.4H // .......................*............................... + smlal v5.4S, v14.4H, v7.4H // .............................*......................... + smlal2 v8.4S, v21.8H, v25.8H // ..................................*.................... + smlal2 v30.4S, v14.8H, v9.8H // ................................*...................... + smlal v16.4S, v14.4H, v9.4H // ...............................*....................... + smlal v5.4S, v21.4H, v25.4H // .................................*..................... + zip1 v20.8H, v27.8H, v10.8H // ..........................................*............ + smlal2 v30.4S, v21.8H, v7.8H // ....................................*.................. + smlal v16.4S, v21.4H, v7.4H // ...................................*................... + uzp1 v7.8H, v5.8H, v8.8H // .....................................*................. + str q20, [x0], #32 // ...............................................*....... + mul v15.8H, v7.8H, v2.8H // .......................................*............... + uzp1 v7.8H, v16.8H, v30.8H // ......................................*................ + zip2 v31.8H, v27.8H, v10.8H // .........................................*............. + mul v20.8H, v7.8H, v2.8H // ........................................*.............. + smlal v5.4S, v15.4H, v0.4H // ...........................................*........... + smlal2 v8.4S, v15.8H, v0.8H // ............................................*.......... + str q31, [x0, #-16] // .................................................*..... + smlal2 v30.4S, v20.8H, v0.8H // ..............................................*........ + smlal v16.4S, v20.4H, v0.4H // .............................................*......... + uzp2 v15.8H, v5.8H, v8.8H // ................................................*...... 
+ uzp2 v20.8H, v16.8H, v30.8H // ..................................................*.... + zip1 v7.8H, v15.8H, v20.8H // ....................................................*.. + zip2 v20.8H, v15.8H, v20.8H // ...................................................*... + str q7, [x0], #32 // .....................................................*. + str q20, [x0, #-16] // ......................................................* + + // -------------------- new position --------------------> + // 0 25 50 + // |------------------------|------------------------|---- + // ldr q20, [x1], #32 // *...................................................... + // uzp1 v7.8H, v15.8H, v16.8H // .*..................................................... + // uzp2 v15.8H, v15.8H, v16.8H // ..*.................................................... + // uzp1 v8.8H, v20.8H, v30.8H // ...*................................................... + // uzp2 v20.8H, v20.8H, v30.8H // ....*.................................................. + // smull v30.4S, v8.4H, v15.4H // ............*.......................................... + // smull2 v15.4S, v8.8H, v15.8H // .......*............................................... + // smull v11.4S, v8.4H, v7.4H // ......*................................................ + // smull2 v8.4S, v8.8H, v7.8H // .....*................................................. + // smlal v30.4S, v20.4H, v7.4H // ...................*................................... + // smlal2 v15.4S, v20.8H, v7.8H // ....................*.................................. + // smlal v11.4S, v20.4H, v9.4H // ..........*............................................ + // smlal2 v8.4S, v20.8H, v9.8H // .........*............................................. + // uzp1 v7.8H, v21.8H, v6.8H // ...........*........................................... + // uzp2 v20.8H, v21.8H, v6.8H // ...............*....................................... 
+ // uzp1 v16.8H, v1.8H, v12.8H // ........*.............................................. + // uzp2 v9.8H, v1.8H, v12.8H // ......................*................................ + // smlal v11.4S, v7.4H, v16.4H // ..............*........................................ + // smlal2 v8.4S, v7.8H, v16.8H // .............*......................................... + // smlal v30.4S, v7.4H, v9.4H // .........................*............................. + // smlal2 v15.4S, v7.8H, v9.8H // ........................*.............................. + // smlal v11.4S, v20.4H, v24.4H // ..........................*............................ + // smlal2 v8.4S, v20.8H, v24.8H // .................*..................................... + // smlal v30.4S, v20.4H, v16.4H // .............................*......................... + // smlal2 v15.4S, v20.8H, v16.8H // ............................*.......................... + // uzp1 v7.8H, v19.8H, v31.8H // ..................*.................................... + // uzp2 v20.8H, v19.8H, v31.8H // .......................*............................... + // uzp1 v16.8H, v17.8H, v18.8H // ................*...................................... + // uzp2 v9.8H, v17.8H, v18.8H // ...........................*........................... + // smlal v11.4S, v7.4H, v16.4H // ..............................*........................ + // smlal2 v8.4S, v7.8H, v16.8H // .....................*................................. + // smlal v30.4S, v7.4H, v9.4H // .................................*..................... + // smlal2 v15.4S, v7.8H, v9.8H // ................................*...................... + // smlal v11.4S, v20.4H, v25.4H // ..................................*.................... + // smlal2 v8.4S, v20.8H, v25.8H // ...............................*....................... + // smlal v30.4S, v20.4H, v16.4H // .....................................*................. 
+ // smlal2 v15.4S, v20.8H, v16.8H // ....................................*.................. + // uzp1 v7.8H, v11.8H, v8.8H // ......................................*................ + // uzp1 v20.8H, v30.8H, v15.8H // .........................................*............. + // mul v7.8H, v7.8H, v2.8H // ........................................*.............. + // mul v20.8H, v20.8H, v2.8H // ...........................................*........... + // zip2 v9.8H, v27.8H, v10.8H // ..........................................*............ + // zip1 v27.8H, v27.8H, v10.8H // ...................................*................... + // smlal v11.4S, v7.4H, v0.4H // ............................................*.......... + // smlal2 v8.4S, v7.8H, v0.8H // .............................................*......... + // smlal v30.4S, v20.4H, v0.4H // ................................................*...... + // smlal2 v15.4S, v20.8H, v0.8H // ...............................................*....... + // str q27, [x0], #32 // .......................................*............... + // uzp2 v27.8H, v11.8H, v8.8H // .................................................*..... + // str q9, [x0, #-16] // ..............................................*........ + // uzp2 v10.8H, v30.8H, v15.8H // ..................................................*.... + // zip2 v9.8H, v27.8H, v10.8H // ....................................................*.. + // zip1 v27.8H, v27.8H, v10.8H // ...................................................*... + // str q27, [x0], #32 // .....................................................*. 
+ // str q9, [x0, #-16] // ......................................................* + + + pop_stack + ret +#endif /* MLKEM_K == 3 */ + +#if MLKEM_K == 4 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + add a3_ptr, a0_ptr, #(3 * 512) + add b3_ptr, b0_ptr, #(3 * 512) + add b3_cache_ptr, b0_cache_ptr, #(3 * 512/2) + + // Bounds: + + // Each vector entry adds two smull/smlal products, together bounded by 2*4096*2^15=2^28, + // so the final value before Montgomery reduction is bounded by 4*2^28=2^30. + + mov count, #(MLKEM_N / 16) + // Instructions: 114 + // Expected cycles: 153 + // Expected IPC: 0.75 + // + // Cycle bound: 153.0 + // IPC bound: 0.75 + // + // Wall time: 0.69s + // User time: 0.69s + // + // ----------------------------------------------- original position -----------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + ldr q23, [x2, #16] // .*................................................................................................................ + ldr q19, [x2], #32 // *................................................................................................................. + ldr q17, [x5], #32 // ..*............................................................................................................... + uzp2 v13.8H, v19.8H, v23.8H // ..........*....................................................................................................... 
+ uzp1 v19.8H, v19.8H, v23.8H // ...........*...................................................................................................... + ldr q23, [x5, #-16] // ...*.............................................................................................................. + ldr q30, [x1, #16] // .....*............................................................................................................ + uzp2 v9.8H, v17.8H, v23.8H // ....*............................................................................................................. + uzp1 v23.8H, v17.8H, v23.8H // .......*.......................................................................................................... + ldr q17, [x1], #32 // ......*........................................................................................................... + ldr q10, [x7, #16] // .............*.................................................................................................... + uzp1 v12.8H, v17.8H, v30.8H // ........*......................................................................................................... + uzp2 v17.8H, v17.8H, v30.8H // .........*........................................................................................................ + smull2 v30.4S, v12.8H, v13.8H // ............*..................................................................................................... + smull v13.4S, v12.4H, v13.4H // ............................................*..................................................................... + smull2 v22.4S, v12.8H, v19.8H // .....................................*............................................................................ + smull v12.4S, v12.4H, v19.4H // ..........................................*....................................................................... 
+ smlal2 v30.4S, v17.8H, v19.8H // ...............................*.................................................................................. + smlal v13.4S, v17.4H, v19.4H // ...............................................*.................................................................. + ldr q19, [x4], #32 // ....................*............................................................................................. + ldr q16, [x4, #-16] // .....................*............................................................................................ + ld1 {v8.8H}, [x3], #16 // ................................*................................................................................. + uzp1 v26.8H, v19.8H, v16.8H // .......................*.......................................................................................... + uzp2 v19.8H, v19.8H, v16.8H // ........................*......................................................................................... + smlal2 v30.4S, v26.8H, v9.8H // .................................*................................................................................ + smlal v13.4S, v26.4H, v9.4H // ..................................................*............................................................... + smlal2 v22.4S, v17.8H, v8.8H // ........................................*......................................................................... + smlal v12.4S, v17.4H, v8.4H // .................................................*................................................................ + smlal2 v30.4S, v19.8H, v23.8H // ...................................*.............................................................................. + smlal v13.4S, v19.4H, v23.4H // .......................................................*.......................................................... 
+ smlal2 v22.4S, v26.8H, v23.8H // ...........................................*...................................................................... + smlal v12.4S, v26.4H, v23.4H // .....................................................*............................................................ + ldr q23, [x7], #32 // ......................*........................................................................................... + ldr q17, [x8, #16] // ..............*................................................................................................... + uzp1 v9.8H, v23.8H, v10.8H // ..........................*....................................................................................... + uzp2 v23.8H, v23.8H, v10.8H // ....................................*............................................................................. + ldr q10, [x10], #32 // ...............*.................................................................................................. + ldr q16, [x10, #-16] // ................*................................................................................................. + ld1 {v8.8H}, [x12], #16 // .................*................................................................................................ + uzp1 v26.8H, v10.8H, v16.8H // ..................*............................................................................................... + uzp2 v10.8H, v10.8H, v16.8H // ...................*.............................................................................................. + ld1 {v16.8H}, [x6], #16 // .........................*........................................................................................ + ldr q3, [x11, #16] // ...........................*...................................................................................... 
+ smlal2 v22.4S, v19.8H, v16.8H // ..............................................*................................................................... + smlal v12.4S, v19.4H, v16.4H // ........................................................*......................................................... + ldr q19, [x11], #32 // ............................*..................................................................................... + ld1 {v16.8H}, [x9], #16 // .............................*.................................................................................... + uzp1 v4.8H, v19.8H, v3.8H // ..................................*............................................................................... + uzp2 v19.8H, v19.8H, v3.8H // .......................................*.......................................................................... + ldr q3, [x8], #32 // ..............................*................................................................................... + ldr q31, [x2], #32 // ......................................*........................................................................... + uzp1 v6.8H, v3.8H, v17.8H // ...................................................*.............................................................. + uzp2 v17.8H, v3.8H, v17.8H // .........................................................*........................................................ + smlal2 v22.4S, v9.8H, v6.8H // ..........................................................*....................................................... + smlal2 v30.4S, v9.8H, v17.8H // ...........................................................*...................................................... + smlal v13.4S, v9.4H, v17.4H // ............................................................*..................................................... 
+ smlal v12.4S, v9.4H, v6.4H // .............................................................*.................................................... + smlal2 v22.4S, v23.8H, v16.8H // ..............................................................*................................................... + smlal2 v30.4S, v23.8H, v6.8H // ...............................................................*.................................................. + smlal v13.4S, v23.4H, v6.4H // ................................................................*................................................. + smlal v12.4S, v23.4H, v16.4H // .................................................................*................................................ + smlal2 v22.4S, v26.8H, v4.8H // ..................................................................*............................................... + smlal2 v30.4S, v26.8H, v19.8H // ...................................................................*.............................................. + smlal v13.4S, v26.4H, v19.4H // ....................................................................*............................................. + smlal v12.4S, v26.4H, v4.4H // .....................................................................*............................................ + smlal2 v22.4S, v10.8H, v8.8H // ......................................................................*........................................... + smlal2 v30.4S, v10.8H, v4.8H // .......................................................................*.......................................... + smlal v13.4S, v10.4H, v4.4H // ........................................................................*......................................... + smlal v12.4S, v10.4H, v8.4H // .........................................................................*........................................ 
+ ldr q19, [x2, #-16] // .........................................*........................................................................ + uzp1 v23.8H, v13.8H, v30.8H // ...........................................................................*...................................... + uzp1 v17.8H, v12.8H, v22.8H // ....................................................................................*............................. + mul v23.8H, v23.8H, v2.8H // .............................................................................*.................................... + uzp2 v21.8H, v31.8H, v19.8H // ................................................................................*................................. + uzp1 v19.8H, v31.8H, v19.8H // ...................................................................................*.............................. + mul v17.8H, v17.8H, v2.8H // .....................................................................................*............................ + smlal v13.4S, v23.4H, v0.4H // .................................................................................*................................ + smlal2 v30.4S, v23.8H, v0.8H // ..................................................................................*............................... + ldr q23, [x5], #32 // .............................................*.................................................................... + smlal2 v22.4S, v17.8H, v0.8H // ...........................................................................................................*...... + uzp2 v15.8H, v13.8H, v30.8H // ......................................................................................*........................... + smlal v12.4S, v17.4H, v0.4H // ............................................................................................................*..... 
+ ldr q17, [x5, #-16] // ................................................*................................................................. + ldr q13, [x1, #16] // ......................................................*........................................................... + uzp2 v27.8H, v23.8H, v17.8H // ....................................................*............................................................. + uzp1 v28.8H, v23.8H, v17.8H // ............................................................................*..................................... + uzp2 v7.8H, v12.8H, v22.8H // ...............................................................................................................*.. + ldr q23, [x1], #32 // ..........................................................................*....................................... + zip1 v5.8H, v7.8H, v15.8H // .................................................................................................................* + ldr q3, [x7, #16] // ........................................................................................*......................... + uzp1 v31.8H, v23.8H, v13.8H // ..............................................................................*................................... + uzp2 v16.8H, v23.8H, v13.8H // ...............................................................................*.................................. + smull2 v24.4S, v31.8H, v21.8H // .......................................................................................*.......................... + ldr q6, [x8, #16] // .........................................................................................*........................ + ldr q23, [x10], #32 // ..........................................................................................*....................... 
+ smlal2 v24.4S, v16.8H, v19.8H // ..........................................................................................................*....... + ldr q17, [x10, #-16] // ...........................................................................................*...................... + ld1 {v22.8H}, [x12], #16 // ............................................................................................*..................... + uzp1 v30.8H, v23.8H, v17.8H // .............................................................................................*.................... + uzp2 v11.8H, v23.8H, v17.8H // ..............................................................................................*................... + ldr q23, [x4], #32 // ...............................................................................................*.................. + ldr q17, [x4, #-16] // ................................................................................................*................. + ldr q4, [x7], #32 // .................................................................................................*................ + uzp1 v20.8H, v23.8H, v17.8H // ..................................................................................................*............... + uzp2 v26.8H, v23.8H, v17.8H // ...................................................................................................*.............. + uzp1 v9.8H, v4.8H, v3.8H // .....................................................................................................*............ + smlal2 v24.4S, v20.8H, v27.8H // ..............................................................................................................*... + ld1 {v8.8H}, [x6], #16 // ....................................................................................................*............. 
+ ldr q25, [x11, #16] // ......................................................................................................*........... + ldr q29, [x11], #32 // .......................................................................................................*.......... + ld1 {v12.8H}, [x9], #16 // ........................................................................................................*......... + uzp1 v10.8H, v29.8H, v25.8H // ................................................................................................................*. + ldr q14, [x8], #32 // .........................................................................................................*........ + ld1 {v23.8H}, [x3], #16 // .............................................................................................................*.... + + // ------------------------------------------------- new position --------------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + // ldr q3, [x2], #32 // .*................................................................................................................ + // ldr q17, [x2, #-16] // *................................................................................................................. + // ldr q21, [x5], #32 // ..*............................................................................................................... + // ldr q19, [x5, #-16] // .....*............................................................................................................ + // uzp2 v27.8H, v21.8H, v19.8H // .......*.......................................................................................................... + // ldr q25, [x1, #16] // ......*........................................................................................................... 
+ // ldr q22, [x1], #32 // .........*........................................................................................................ + // uzp1 v28.8H, v21.8H, v19.8H // ........*......................................................................................................... + // uzp1 v31.8H, v22.8H, v25.8H // ...........*...................................................................................................... + // uzp2 v16.8H, v22.8H, v25.8H // ............*..................................................................................................... + // uzp2 v21.8H, v3.8H, v17.8H // ...*.............................................................................................................. + // uzp1 v19.8H, v3.8H, v17.8H // ....*............................................................................................................. + // smull2 v24.4S, v31.8H, v21.8H // .............*.................................................................................................... + // ldr q3, [x7, #16] // ..........*....................................................................................................... + // ldr q6, [x8, #16] // .................................*................................................................................ + // ldr q8, [x10], #32 // ....................................*............................................................................. + // ldr q26, [x10, #-16] // .....................................*............................................................................ + // ld1 {v22.8H}, [x12], #16 // ......................................*........................................................................... + // uzp1 v30.8H, v8.8H, v26.8H // .......................................*.......................................................................... 
+ // uzp2 v11.8H, v8.8H, v26.8H // ........................................*......................................................................... + // ldr q8, [x4], #32 // ...................*.............................................................................................. + // ldr q26, [x4, #-16] // ....................*............................................................................................. + // ldr q4, [x7], #32 // ................................*................................................................................. + // uzp1 v20.8H, v8.8H, v26.8H // ......................*........................................................................................... + // uzp2 v26.8H, v8.8H, v26.8H // .......................*.......................................................................................... + // ld1 {v8.8H}, [x6], #16 // .........................................*........................................................................ + // uzp1 v9.8H, v4.8H, v3.8H // ..................................*............................................................................... + // ldr q25, [x11, #16] // ..........................................*....................................................................... + // ldr q29, [x11], #32 // .............................................*.................................................................... + // ld1 {v12.8H}, [x9], #16 // ..............................................*................................................................... + // ldr q14, [x8], #32 // .................................................*................................................................ + // smlal2 v24.4S, v16.8H, v19.8H // .................*................................................................................................ 
+ // ld1 {v23.8H}, [x3], #16 // .....................*............................................................................................ + // smlal2 v24.4S, v20.8H, v27.8H // ........................*......................................................................................... + // uzp1 v10.8H, v29.8H, v25.8H // ...............................................*.................................................................. + // smlal2 v24.4S, v26.8H, v28.8H // ............................*..................................................................................... + // uzp2 v4.8H, v4.8H, v3.8H // ...................................*.............................................................................. + // smull2 v13.4S, v31.8H, v19.8H // ...............*.................................................................................................. + // ldr q3, [x2], #32 // ..................................................*............................................................... + // uzp2 v1.8H, v29.8H, v25.8H // ................................................*................................................................. + // smlal2 v13.4S, v16.8H, v23.8H // ..........................*....................................................................................... + // ldr q17, [x2, #-16] // .....................................................................*............................................ + // smull v18.4S, v31.4H, v19.4H // ................*................................................................................................. + // smlal2 v13.4S, v20.8H, v28.8H // ..............................*................................................................................... + // smull v29.4S, v31.4H, v21.4H // ..............*................................................................................................... 
+ // ldr q21, [x5], #32 // ..............................................................................*................................... + // smlal2 v13.4S, v26.8H, v8.8H // ...........................................*...................................................................... + // smlal v29.4S, v16.4H, v19.4H // ..................*............................................................................................... + // ldr q19, [x5, #-16] // ..................................................................................*............................... + // smlal v18.4S, v16.4H, v23.4H // ...........................*...................................................................................... + // smlal v29.4S, v20.4H, v27.4H // .........................*........................................................................................ + // uzp1 v31.8H, v14.8H, v6.8H // ...................................................*.............................................................. + // uzp2 v27.8H, v21.8H, v19.8H // ....................................................................................*............................. + // smlal v18.4S, v20.4H, v28.4H // ...............................*.................................................................................. + // ldr q25, [x1, #16] // ...................................................................................*.............................. + // smlal v29.4S, v26.4H, v28.4H // .............................*.................................................................................... + // smlal v18.4S, v26.4H, v8.4H // ............................................*..................................................................... + // uzp2 v26.8H, v14.8H, v6.8H // ....................................................*............................................................. 
+ // smlal2 v13.4S, v9.8H, v31.8H // .....................................................*............................................................ + // smlal2 v24.4S, v9.8H, v26.8H // ......................................................*........................................................... + // smlal v29.4S, v9.4H, v26.4H // .......................................................*.......................................................... + // smlal v18.4S, v9.4H, v31.4H // ........................................................*......................................................... + // smlal2 v13.4S, v4.8H, v12.8H // .........................................................*........................................................ + // smlal2 v24.4S, v4.8H, v31.8H // ..........................................................*....................................................... + // smlal v29.4S, v4.4H, v31.4H // ...........................................................*...................................................... + // smlal v18.4S, v4.4H, v12.4H // ............................................................*..................................................... + // smlal2 v13.4S, v30.8H, v10.8H // .............................................................*.................................................... + // smlal2 v24.4S, v30.8H, v1.8H // ..............................................................*................................................... + // smlal v29.4S, v30.4H, v1.4H // ...............................................................*.................................................. + // smlal v18.4S, v30.4H, v10.4H // ................................................................*................................................. + // smlal2 v13.4S, v11.8H, v22.8H // .................................................................*................................................ 
+ // smlal2 v24.4S, v11.8H, v10.8H // ..................................................................*............................................... + // smlal v29.4S, v11.4H, v10.4H // ...................................................................*.............................................. + // smlal v18.4S, v11.4H, v22.4H // ....................................................................*............................................. + // ldr q22, [x1], #32 // .......................................................................................*.......................... + // uzp1 v31.8H, v29.8H, v24.8H // ......................................................................*........................................... + // uzp1 v28.8H, v21.8H, v19.8H // .....................................................................................*............................ + // mul v19.8H, v31.8H, v2.8H // ........................................................................*......................................... + // uzp1 v31.8H, v22.8H, v25.8H // ..........................................................................................*....................... + // uzp2 v16.8H, v22.8H, v25.8H // ...........................................................................................*...................... + // uzp2 v21.8H, v3.8H, v17.8H // .........................................................................*........................................ + // smlal v29.4S, v19.4H, v0.4H // ............................................................................*..................................... + // smlal2 v24.4S, v19.8H, v0.8H // .............................................................................*.................................... + // uzp1 v19.8H, v3.8H, v17.8H // ..........................................................................*....................................... 
+ // uzp1 v26.8H, v18.8H, v13.8H // .......................................................................*.......................................... + // mul v23.8H, v26.8H, v2.8H // ...........................................................................*...................................... + // uzp2 v15.8H, v29.8H, v24.8H // ................................................................................*................................. + // smull2 v24.4S, v31.8H, v21.8H // ............................................................................................*..................... + // ldr q3, [x7, #16] // .........................................................................................*........................ + // ldr q6, [x8, #16] // .............................................................................................*.................... + // ldr q8, [x10], #32 // ..............................................................................................*................... + // ldr q26, [x10, #-16] // ................................................................................................*................. + // ld1 {v22.8H}, [x12], #16 // .................................................................................................*................ + // uzp1 v30.8H, v8.8H, v26.8H // ..................................................................................................*............... + // uzp2 v11.8H, v8.8H, v26.8H // ...................................................................................................*.............. + // ldr q8, [x4], #32 // ....................................................................................................*............. + // ldr q26, [x4, #-16] // .....................................................................................................*............ 
+ // ldr q4, [x7], #32 // ......................................................................................................*........... + // uzp1 v20.8H, v8.8H, v26.8H // .......................................................................................................*.......... + // uzp2 v26.8H, v8.8H, v26.8H // ........................................................................................................*......... + // ld1 {v8.8H}, [x6], #16 // ...........................................................................................................*...... + // uzp1 v9.8H, v4.8H, v3.8H // .........................................................................................................*........ + // ldr q25, [x11, #16] // ............................................................................................................*..... + // ldr q29, [x11], #32 // .............................................................................................................*.... + // ld1 {v12.8H}, [x9], #16 // ..............................................................................................................*... + // ldr q14, [x8], #32 // ................................................................................................................*. + // smlal2 v24.4S, v16.8H, v19.8H // ...............................................................................................*.................. + // smlal2 v13.4S, v23.8H, v0.8H // ...............................................................................*.................................. + // smlal v18.4S, v23.4H, v0.4H // .................................................................................*................................ 
+ // ld1 {v23.8H}, [x3], #16 // .................................................................................................................* + // smlal2 v24.4S, v20.8H, v27.8H // ..........................................................................................................*....... + // uzp2 v7.8H, v18.8H, v13.8H // ......................................................................................*........................... + // uzp1 v10.8H, v29.8H, v25.8H // ...............................................................................................................*.. + // zip1 v5.8H, v7.8H, v15.8H // ........................................................................................*......................... + + sub count, count, #2 +1: + // Instructions: 82 + // Expected cycles: 102 + // Expected IPC: 0.80 + // + // Cycle bound: 102.0 + // IPC bound: 0.80 + // + // Wall time: 15.93s + // User time: 15.93s + // + // ------------------------------- original position -------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------ + smlal2 v24.4S, v26.8H, v28.8H // .................................*................................................ + uzp2 v4.8H, v4.8H, v3.8H // .....................................*............................................ + smull2 v13.4S, v31.8H, v19.8H // ..........*....................................................................... + ldr q3, [x2], #32 // ....e............................................................................. + uzp2 v1.8H, v29.8H, v25.8H // ..........................................................*....................... + smlal2 v13.4S, v16.8H, v23.8H // ............*..................................................................... + ldr q17, [x2, #-16] // .....e............................................................................ 
+ smull v18.4S, v31.4H, v19.4H // .........*........................................................................ + smlal2 v13.4S, v20.8H, v28.8H // ...........................*...................................................... + smull v29.4S, v31.4H, v21.4H // .............*.................................................................... + ldr q21, [x5], #32 // .....................e............................................................ + smlal2 v13.4S, v26.8H, v8.8H // .............................*.................................................... + smlal v29.4S, v16.4H, v19.4H // ...............*.................................................................. + ldr q19, [x5, #-16] // ......................e........................................................... + smlal v18.4S, v16.4H, v23.4H // ...........*...................................................................... + smlal v29.4S, v20.4H, v27.4H // ..............................*................................................... + uzp1 v31.8H, v14.8H, v6.8H // ........................................*......................................... + uzp2 v27.8H, v21.8H, v19.8H // ........................e......................................................... + smlal v18.4S, v20.4H, v28.4H // ..........................*....................................................... + ldr q25, [x1, #16] // .e................................................................................ + smlal v29.4S, v26.4H, v28.4H // ................................*................................................. + smlal v18.4S, v26.4H, v8.4H // ............................*..................................................... + uzp2 v26.8H, v14.8H, v6.8H // .........................................*........................................ + smlal2 v13.4S, v9.8H, v31.8H // ............................................*..................................... 
+ smlal2 v24.4S, v9.8H, v26.8H // ................................................*................................. + smlal v29.4S, v9.4H, v26.4H // ...............................................*.................................. + smlal v18.4S, v9.4H, v31.4H // ...........................................*...................................... + smlal2 v13.4S, v4.8H, v12.8H // ..............................................*................................... + smlal2 v24.4S, v4.8H, v31.8H // ..................................................*............................... + smlal v29.4S, v4.4H, v31.4H // .................................................*................................ + smlal v18.4S, v4.4H, v12.4H // .............................................*.................................... + smlal2 v13.4S, v30.8H, v10.8H // .............................................................*.................... + smlal2 v24.4S, v30.8H, v1.8H // .................................................................*................ + smlal v29.4S, v30.4H, v1.4H // ................................................................*................. + smlal v18.4S, v30.4H, v10.4H // ............................................................*..................... + smlal2 v13.4S, v11.8H, v22.8H // ...............................................................*.................. + smlal2 v24.4S, v11.8H, v10.8H // ...................................................................*.............. + smlal v29.4S, v11.4H, v10.4H // ..................................................................*............... + smlal v18.4S, v11.4H, v22.4H // ..............................................................*................... + ldr q22, [x1], #32 // e................................................................................. + uzp1 v31.8H, v29.8H, v24.8H // .........................................................................*........ 
+ uzp1 v28.8H, v21.8H, v19.8H // .......................e.......................................................... + mul v19.8H, v31.8H, v2.8H // ..........................................................................*....... + uzp1 v31.8H, v22.8H, v25.8H // ..e............................................................................... + uzp2 v16.8H, v22.8H, v25.8H // ...e.............................................................................. + uzp2 v21.8H, v3.8H, v17.8H // .......e.......................................................................... + smlal v29.4S, v19.4H, v0.4H // ...........................................................................*...... + smlal2 v24.4S, v19.8H, v0.8H // ............................................................................*..... + uzp1 v19.8H, v3.8H, v17.8H // ......e........................................................................... + uzp1 v26.8H, v18.8H, v13.8H // ....................................................................*............. + zip2 v14.8H, v7.8H, v15.8H // ...............................................................................l.. + mul v23.8H, v26.8H, v2.8H // .....................................................................*............ + uzp2 v15.8H, v29.8H, v24.8H // .............................................................................*.... + smull2 v24.4S, v31.8H, v21.8H // ..............e................................................................... + str q14, [x0, #16] // .................................................................................l + ldr q3, [x7, #16] // ...................................e.............................................. + ldr q6, [x8, #16] // .......................................e.......................................... + ldr q8, [x10], #32 // ...................................................e.............................. 
+ ldr q26, [x10, #-16] // ....................................................e............................. + ld1 {v22.8H}, [x12], #16 // ...........................................................e...................... + uzp1 v30.8H, v8.8H, v26.8H // .....................................................e............................ + uzp2 v11.8H, v8.8H, v26.8H // ......................................................e........................... + ldr q8, [x4], #32 // .................e................................................................ + ldr q26, [x4, #-16] // ..................e............................................................... + ldr q4, [x7], #32 // ..................................e............................................... + uzp1 v20.8H, v8.8H, v26.8H // ...................e.............................................................. + uzp2 v26.8H, v8.8H, v26.8H // ....................e............................................................. + ld1 {v8.8H}, [x6], #16 // .........................e........................................................ + uzp1 v9.8H, v4.8H, v3.8H // ....................................e............................................. + ldr q25, [x11, #16] // ........................................................e......................... + ldr q29, [x11], #32 // .......................................................e.......................... + ld1 {v12.8H}, [x9], #16 // ..........................................e....................................... + ldr q14, [x8], #32 // ......................................e........................................... + smlal2 v24.4S, v16.8H, v19.8H // ................e................................................................. + smlal2 v13.4S, v23.8H, v0.8H // .......................................................................*.......... 
+ smlal v18.4S, v23.4H, v0.4H // ......................................................................*........... + ld1 {v23.8H}, [x3], #16 // ........e......................................................................... + smlal2 v24.4S, v20.8H, v27.8H // ...............................e.................................................. + uzp2 v7.8H, v18.8H, v13.8H // ........................................................................*......... + uzp1 v10.8H, v29.8H, v25.8H // .........................................................e........................ + str q5, [x0], #32 // ................................................................................l. + zip1 v5.8H, v7.8H, v15.8H // ..............................................................................*... + + // ----------------------------------------------------------------------------------------------------------------- new position ------------------------------------------------------------------------------------------------------------------> + // 0 25 50 75 100 125 150 175 200 225 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|---------------- + // ldr q12, [x1], #32 // ....................................e..........................................'......................................~..........................................'......................................~......................................... + // ldr q13, [x1, #-16] // ................e..............................................................'..................~..............................................................'..................~............................................................. 
+ // uzp1 v3.8h, v12.8h, v13.8h // ........................................e......................................'..........................................~......................................'..........................................~..................................... + // uzp2 v4.8h, v12.8h, v13.8h // .........................................e.....................................'...........................................~.....................................'...........................................~.................................... + // ldr q12, [x2], #32 // e..............................................................................'..~..............................................................................'..~............................................................................. + // ldr q13, [x2, #-16] // ...e...........................................................................'.....~...........................................................................'.....~.......................................................................... + // uzp1 v5.8h, v12.8h, v13.8h // .............................................e.................................'...............................................~.................................'...............................................~................................ + // uzp2 v6.8h, v12.8h, v13.8h // ..........................................e....................................'............................................~....................................'............................................~................................... + // ld1 {v7.8h}, [x3], #16 // .........................................................................e.....'...........................................................................~.....'...........................................................................~.... 
+ // smull v8.4s, v3.4h, v5.4h // ....~..........................................................................'......*..........................................................................'......~......................................................................... + // smull2 v10.4s, v3.8h, v5.8h // ...............................................................................'.*...............................................................................'.~.............................................................................. + // smlal v8.4s, v4.4h, v7.4h // ...........~...................................................................'.............*...................................................................'.............~.................................................................. + // smlal2 v10.4s, v4.8h, v7.8h // ..~............................................................................'....*............................................................................'....~........................................................................... + // smull v9.4s, v3.4h, v6.4h // ......~........................................................................'........*........................................................................'........~....................................................................... + // smull2 v11.4s, v3.8h, v6.8h // ..................................................e............................'....................................................~............................'....................................................~........................... + // smlal v9.4s, v4.4h, v5.4h // .........~.....................................................................'...........*.....................................................................'...........~.................................................................... 
+ // smlal2 v11.4s, v4.8h, v5.8h // ......................................................................e........'........................................................................~........'........................................................................~....... + // ldr q12, [x4], #32 // ...........................................................e...................'.............................................................~...................'.............................................................~.................. + // ldr q13, [x4, #-16] // ............................................................e..................'..............................................................~..................'..............................................................~................. + // uzp1 v3.8h, v12.8h, v13.8h // ..............................................................e................'................................................................~................'................................................................~............... + // uzp2 v4.8h, v12.8h, v13.8h // ...............................................................e...............'.................................................................~...............'.................................................................~.............. + // ldr q12, [x5], #32 // .......e.......................................................................'.........~.......................................................................'.........~...................................................................... + // ldr q13, [x5, #-16] // ..........e....................................................................'............~....................................................................'............~................................................................... 
+ // uzp1 v5.8h, v12.8h, v13.8h // ......................................e........................................'........................................~........................................'........................................~....................................... + // uzp2 v6.8h, v12.8h, v13.8h // ..............e................................................................'................~................................................................'................~............................................................... + // ld1 {v7.8h}, [x6], #16 // ................................................................e..............'..................................................................~..............'..................................................................~............. + // smlal v8.4s, v3.4h, v5.4h // ...............~...............................................................'.................*...............................................................'.................~.............................................................. + // smlal2 v10.4s, v3.8h, v5.8h // .....~.........................................................................'.......*.........................................................................'.......~........................................................................ + // smlal v8.4s, v4.4h, v7.4h // ..................~............................................................'....................*............................................................'....................~........................................................... + // smlal2 v10.4s, v4.8h, v7.8h // ........~......................................................................'..........*......................................................................'..........~..................................................................... 
+ // smlal v9.4s, v3.4h, v6.4h // ............~..................................................................'..............*..................................................................'..............~................................................................. + // smlal2 v11.4s, v3.8h, v6.8h // ..........................................................................e....'............................................................................~....'............................................................................~... + // smlal v9.4s, v4.4h, v5.4h // .................~.............................................................'...................*.............................................................'...................~............................................................ + // smlal2 v11.4s, v4.8h, v5.8h // ...............................................................................*.................................................................................~................................................................................ + // ldr q12, [x7], #32 // .............................................................e.................'...............................................................~.................'...............................................................~................ + // ldr q13, [x7, #-16] // ....................................................e..........................'......................................................~..........................'......................................................~......................... + // uzp1 v3.8h, v12.8h, v13.8h // .................................................................e.............'...................................................................~.............'...................................................................~............ 
+ // uzp2 v4.8h, v12.8h, v13.8h // ...............................................................................'*................................................................................'~............................................................................... + // ldr q12, [x8], #32 // .....................................................................e.........'.......................................................................~.........'.......................................................................~........ + // ldr q13, [x8, #-16] // .....................................................e.........................'.......................................................~.........................'.......................................................~........................ + // uzp1 v5.8h, v12.8h, v13.8h // .............~.................................................................'...............*.................................................................'...............~................................................................ + // uzp2 v6.8h, v12.8h, v13.8h // ...................~...........................................................'.....................*...........................................................'.....................~.......................................................... + // ld1 {v7.8h}, [x9], #16 // ....................................................................e..........'......................................................................~..........'......................................................................~......... + // smlal v8.4s, v3.4h, v5.4h // .......................~.......................................................'.........................*.......................................................'.........................~...................................................... 
+ // smlal2 v10.4s, v3.8h, v5.8h // ....................~..........................................................'......................*..........................................................'......................~......................................................... + // smlal v8.4s, v4.4h, v7.4h // ...........................~...................................................'.............................*...................................................'.............................~.................................................. + // smlal2 v10.4s, v4.8h, v7.8h // ........................~......................................................'..........................*......................................................'..........................~..................................................... + // smlal v9.4s, v3.4h, v6.4h // ......................~........................................................'........................*........................................................'........................~....................................................... + // smlal2 v11.4s, v3.8h, v6.8h // .....................~.........................................................'.......................*.........................................................'.......................~........................................................ + // smlal v9.4s, v4.4h, v5.4h // ..........................~....................................................'............................*....................................................'............................~................................................... + // smlal2 v11.4s, v4.8h, v5.8h // .........................~.....................................................'...........................*.....................................................'...........................~.................................................... 
+ // ldr q12, [x10], #32 // ......................................................e........................'........................................................~........................'........................................................~....................... + // ldr q13, [x10, #-16] // .......................................................e.......................'.........................................................~.......................'.........................................................~...................... + // uzp1 v3.8h, v12.8h, v13.8h // .........................................................e.....................'...........................................................~.....................'...........................................................~.................... + // uzp2 v4.8h, v12.8h, v13.8h // ..........................................................e....................'............................................................~....................'............................................................~................... + // ldr q12, [x11], #32 // ...................................................................e...........'.....................................................................~...........'.....................................................................~.......... + // ldr q13, [x11, #-16] // ..................................................................e............'....................................................................~............'....................................................................~........... + // uzp1 v5.8h, v12.8h, v13.8h // ............................................................................e..'..............................................................................~..'..............................................................................~. 
+ // uzp2 v6.8h, v12.8h, v13.8h // .~.............................................................................'...*.............................................................................'...~............................................................................ + // ld1 {v7.8h}, [x12], #16 // ........................................................e......................'..........................................................~......................'..........................................................~..................... + // smlal v8.4s, v3.4h, v5.4h // ...............................~...............................................'.................................*...............................................'.................................~.............................................. + // smlal2 v10.4s, v3.8h, v5.8h // ............................~..................................................'..............................*..................................................'..............................~................................................. + // smlal v8.4s, v4.4h, v7.4h // ...................................~...........................................'.....................................*...........................................'.....................................~.......................................... + // smlal2 v10.4s, v4.8h, v7.8h // ................................~..............................................'..................................*..............................................'..................................~............................................. + // smlal v9.4s, v3.4h, v6.4h // ..............................~................................................'................................*................................................'................................~............................................... 
+ // smlal2 v11.4s, v3.8h, v6.8h // .............................~.................................................'...............................*.................................................'...............................~................................................ + // smlal v9.4s, v4.4h, v5.4h // ..................................~............................................'....................................*............................................'....................................~........................................... + // smlal2 v11.4s, v4.8h, v5.8h // .................................~.............................................'...................................*.............................................'...................................~............................................ + // uzp1 v28.8h, v8.8h, v10.8h // ..............................................~................................'................................................*................................'................................................~............................... + // mul v28.8h, v28.8h, v2.8h // ................................................~..............................'..................................................*..............................'..................................................~............................. + // smlal v8.4s, v28.4h, v0.4h // ........................................................................~......'..........................................................................*......'..........................................................................~..... + // smlal2 v10.4s, v28.8h, v0.8h // .......................................................................~.......'.........................................................................*.......'.........................................................................~...... 
+ // uzp2 v26.8h, v8.8h, v10.8h // ...........................................................................~...'.............................................................................*...'.............................................................................~.. + // uzp1 v28.8h, v9.8h, v11.8h // .....................................~.........................................'.......................................*.........................................'.......................................~........................................ + // mul v28.8h, v28.8h, v2.8h // .......................................~.......................................'.........................................*.......................................'.........................................~...................................... + // smlal v9.4s, v28.4h, v0.4h // ...........................................~...................................'.............................................*...................................'.............................................~.................................. + // smlal2 v11.4s, v28.8h, v0.8h // ............................................~..................................'..............................................*..................................'..............................................~................................. + // uzp2 v27.8h, v9.8h, v11.8h // .................................................~.............................'...................................................*.............................'...................................................~............................ + // zip1 v12.8h, v26.8h, v27.8h // ..............................................................................~'................................................................................*'................................................................................ 
+ // zip2 v13.8h, v26.8h, v27.8h // ...............................................~...............................'.................................................~...............................'.................................................l.............................. + // str q12, [x0], #32 // .............................................................................~.'...............................................................................~.'...............................................................................l + // str q13, [x0, #-16] // ...................................................~...........................'.....................................................~...........................'.....................................................l.......................... + + sub count, count, #1 + cbnz count, 1b + // Instructions: 50 + // Expected cycles: 56 + // Expected IPC: 0.89 + // + // Cycle bound: 56.0 + // IPC bound: 0.89 + // + // Wall time: 4.16s + // User time: 4.16s + // + // --------------- original position ---------------> + // 0 25 + // |------------------------| + smull2 v17.4S, v31.8H, v19.8H // ..*............................................... + uzp2 v1.8H, v14.8H, v6.8H // ................*................................. + smull v18.4S, v31.4H, v21.4H // .......*.......................................... + smlal2 v24.4S, v26.8H, v28.8H // *................................................. + smlal2 v17.4S, v16.8H, v23.8H // ....*............................................. + smull v21.4S, v31.4H, v19.4H // .....*............................................ + smlal v18.4S, v16.4H, v19.4H // .........*........................................ + uzp2 v31.8H, v4.8H, v3.8H // .*................................................ + uzp1 v3.8H, v14.8H, v6.8H // ............*..................................... + smlal v21.4S, v16.4H, v23.4H // ..........*....................................... 
+ smlal v18.4S, v20.4H, v27.4H // ...........*...................................... + uzp2 v14.8H, v29.8H, v25.8H // ...*.............................................. + smlal2 v17.4S, v20.8H, v28.8H // ......*........................................... + smlal v21.4S, v20.4H, v28.4H // .............*.................................... + smlal v18.4S, v26.4H, v28.4H // ..............*................................... + smlal2 v24.4S, v9.8H, v1.8H // ..................*............................... + smlal2 v17.4S, v26.8H, v8.8H // ........*......................................... + smlal v21.4S, v26.4H, v8.4H // ...............*.................................. + smlal v18.4S, v9.4H, v1.4H // ...................*.............................. + smlal2 v24.4S, v31.8H, v3.8H // ......................*........................... + smlal2 v17.4S, v9.8H, v3.8H // .................*................................ + smlal v21.4S, v9.4H, v3.4H // ....................*............................. + smlal v18.4S, v31.4H, v3.4H // .......................*.......................... + smlal2 v24.4S, v30.8H, v14.8H // ..........................*....................... + smlal2 v17.4S, v31.8H, v12.8H // .....................*............................ + smlal v21.4S, v31.4H, v12.4H // ........................*......................... + smlal v18.4S, v30.4H, v14.4H // ...........................*...................... + smlal2 v24.4S, v11.8H, v10.8H // ..............................*................... + smlal2 v17.4S, v30.8H, v10.8H // .........................*........................ + smlal v21.4S, v30.4H, v10.4H // ............................*..................... + smlal v18.4S, v11.4H, v10.4H // ...............................*.................. + zip2 v19.8H, v7.8H, v15.8H // ......................................*........... + smlal2 v17.4S, v11.8H, v22.8H // .............................*.................... 
+ smlal v21.4S, v11.4H, v22.4H // ................................*................. + uzp1 v23.8H, v18.8H, v24.8H // .................................*................ + str q19, [x0, #16] // .........................................*........ + mul v19.8H, v23.8H, v2.8H // ..................................*............... + uzp1 v23.8H, v21.8H, v17.8H // .....................................*............ + str q5, [x0], #32 // .............................................*.... + mul v26.8H, v23.8H, v2.8H // .......................................*.......... + smlal v18.4S, v19.4H, v0.4H // ...................................*.............. + smlal2 v24.4S, v19.8H, v0.8H // ....................................*............. + smlal v21.4S, v26.4H, v0.4H // ...........................................*...... + smlal2 v17.4S, v26.8H, v0.8H // ..........................................*....... + uzp2 v13.8H, v18.8H, v24.8H // ........................................*......... + uzp2 v19.8H, v21.8H, v17.8H // ............................................*..... + zip1 v23.8H, v19.8H, v13.8H // ..............................................*... + zip2 v19.8H, v19.8H, v13.8H // ...............................................*.. + str q23, [x0], #32 // .................................................* + str q19, [x0, #-16] // ................................................*. + + // ----------------- new position ------------------> + // 0 25 + // |------------------------|------------------------ + // smlal2 v24.4S, v26.8H, v28.8H // ...*.............................................. + // uzp2 v4.8H, v4.8H, v3.8H // .......*.......................................... + // smull2 v13.4S, v31.8H, v19.8H // *................................................. + // uzp2 v1.8H, v29.8H, v25.8H // ...........*...................................... + // smlal2 v13.4S, v16.8H, v23.8H // ....*............................................. 
+ // smull v18.4S, v31.4H, v19.4H // .....*............................................ + // smlal2 v13.4S, v20.8H, v28.8H // ............*..................................... + // smull v29.4S, v31.4H, v21.4H // ..*............................................... + // smlal2 v13.4S, v26.8H, v8.8H // ................*................................. + // smlal v29.4S, v16.4H, v19.4H // ......*........................................... + // smlal v18.4S, v16.4H, v23.4H // .........*........................................ + // smlal v29.4S, v20.4H, v27.4H // ..........*....................................... + // uzp1 v31.8H, v14.8H, v6.8H // ........*......................................... + // smlal v18.4S, v20.4H, v28.4H // .............*.................................... + // smlal v29.4S, v26.4H, v28.4H // ..............*................................... + // smlal v18.4S, v26.4H, v8.4H // .................*................................ + // uzp2 v26.8H, v14.8H, v6.8H // .*................................................ + // smlal2 v13.4S, v9.8H, v31.8H // ....................*............................. + // smlal2 v24.4S, v9.8H, v26.8H // ...............*.................................. + // smlal v29.4S, v9.4H, v26.4H // ..................*............................... + // smlal v18.4S, v9.4H, v31.4H // .....................*............................ + // smlal2 v13.4S, v4.8H, v12.8H // ........................*......................... + // smlal2 v24.4S, v4.8H, v31.8H // ...................*.............................. + // smlal v29.4S, v4.4H, v31.4H // ......................*........................... + // smlal v18.4S, v4.4H, v12.4H // .........................*........................ + // smlal2 v13.4S, v30.8H, v10.8H // ............................*..................... + // smlal2 v24.4S, v30.8H, v1.8H // .......................*.......................... 
+ // smlal v29.4S, v30.4H, v1.4H // ..........................*....................... + // smlal v18.4S, v30.4H, v10.4H // .............................*.................... + // smlal2 v13.4S, v11.8H, v22.8H // ................................*................. + // smlal2 v24.4S, v11.8H, v10.8H // ...........................*...................... + // smlal v29.4S, v11.4H, v10.4H // ..............................*................... + // smlal v18.4S, v11.4H, v22.4H // .................................*................ + // uzp1 v31.8H, v29.8H, v24.8H // ..................................*............... + // mul v19.8H, v31.8H, v2.8H // ....................................*............. + // smlal v29.4S, v19.4H, v0.4H // ........................................*......... + // smlal2 v24.4S, v19.8H, v0.8H // .........................................*........ + // uzp1 v26.8H, v18.8H, v13.8H // .....................................*............ + // zip2 v14.8H, v7.8H, v15.8H // ...............................*.................. + // mul v23.8H, v26.8H, v2.8H // .......................................*.......... + // uzp2 v15.8H, v29.8H, v24.8H // ............................................*..... + // str q14, [x0, #16] // ...................................*.............. + // smlal2 v13.4S, v23.8H, v0.8H // ...........................................*...... + // smlal v18.4S, v23.4H, v0.4H // ..........................................*....... + // uzp2 v7.8H, v18.8H, v13.8H // .............................................*.... + // str q5, [x0], #32 // ......................................*........... + // zip1 v5.8H, v7.8H, v15.8H // ..............................................*... + // zip2 v14.8H, v7.8H, v15.8H // ...............................................*.. + // str q14, [x0, #16] // .................................................* + // str q5, [x0], #32 // ................................................*. 
+ + + pop_stack + ret +#endif /* MLKEM_K == 4 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_asm_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_asm_clean.S new file mode 100644 index 0000000000..722dc0f49e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_asm_clean.S @@ -0,0 +1,341 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/************************************************* + * Name: rej_uniform_asm_clean + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q. Also takes, in x3, the base address of the lookup table used with `tbl` (see table_idx below). + * + * Arguments: - int16_t *r: pointer to output buffer of MLKEM_N + * 16-bit coefficients. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * Must be a multiple of 24. + * + * Returns number of sampled 16-bit integers (at most MLKEM_N). + **************************************************/ +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +// We save the output on the stack first, and copy to the actual +// output buffer only at the end. This is because the main loop can write +// up to 62 bytes past the MLKEM_N coefficients, which we account for here (we use 64 bytes for alignment). 
+#define STACK_SIZE (2*MLKEM_N + 64) +#define STACK_OFFSET_TMP_OUTPUT 0 + +.macro push_stack + sub sp, sp, #STACK_SIZE +.endm + +.macro pop_stack + add sp, sp, #STACK_SIZE +.endm + + /* Parameters (table_idx: base address of the `tbl` index lookup table, 16 bytes per entry, indexed by the 8-bit lane bitmap) */ + output .req x0 + buf .req x1 + buflen .req w2 + table_idx .req x3 + + len .req w4 + + /* Temporary output on the stack */ + output_tmp .req x7 + output_tmp_base .req x8 + + /* Number of coefficients sampled so far (buf_consumed is declared but not referenced in this routine) */ + count .req w9 + buf_consumed .req w10 + + /* Temporary registers (tmp and final_copy_count share w11; rec_idx_N and ctrN share w12-w15) */ + tmp .req w11 + final_copy_count .req w11 + + rec_idx_0 .req w12 + rec_idx_1 .req w13 + rec_idx_2 .req w14 + rec_idx_3 .req w15 + + ctr0 .req w12 + ctr1 .req w13 + ctr2 .req w14 + ctr3 .req w15 + + ctr01 .req ctr0 + ctr23 .req ctr2 + + /* Vector registers */ + + buf0 .req v0 + buf1 .req v1 + buf2 .req v2 + + tmp0 .req v4 + tmp1 .req v5 + tmp2 .req v6 + tmp3 .req v7 + + sign0 .req v4 + sign1 .req v5 + sign2 .req v6 + sign3 .req v7 + + val0 .req v16 + val0q .req q16 + val1 .req v17 + val1q .req q17 + val2 .req v18 + val2q .req q18 + val3 .req v19 + val3q .req q19 + + t0 .req s20 + t1 .req s21 + t2 .req s22 + t3 .req s23 + + table0 .req v24 + table0q .req q24 + table1 .req v25 + table1q .req q25 + table2 .req v26 + table2q .req q26 + table3 .req v27 + table3q .req q27 + + mlkem_q .req v30 + bits .req v31 + bits_q .req q31 + +.text +/* Literal pool: lane i holds the value 2^i; ANDed with the per-lane validity masks to build the lane bitmap */ +.p2align 4 +c_bit_table: + .short 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80 + +.align 4 +.global MLKEM_ASM_NAMESPACE(rej_uniform_asm_clean) +MLKEM_ASM_NAMESPACE(rej_uniform_asm_clean): + push_stack + + ldr bits_q, c_bit_table + movz tmp, #MLKEM_Q + dup mlkem_q.8h, tmp + + add output_tmp_base, sp, #STACK_OFFSET_TMP_OUTPUT + mov output_tmp, output_tmp_base + + mov count, #0 + mov len, #MLKEM_N + + cmp buflen, #48 + b.lo loop48_end + +loop48: + // Finish once we've generated sufficiently many coefficients + cmp count, len + b.hs memory_copy + + // First, we unpack the byte stream into a stream of unsigned + // coefficients, interpreting each
consecutive 3 bytes as two + // unsigned 12-bit coefficients, presented as 16-bit integers. + // + // We handle 16 such triples at a time, and use ld3 for the required + // de-interleaving of the byte stream. + sub buflen, buflen, #48 + ld3 {buf0.16b, buf1.16b, buf2.16b}, [buf], #48 + + // Unpack 16 triples of bytes into 16 pairs of 16-bit integers, + // represented as 4 vectors val0-val3. + zip1 tmp0.16b, buf0.16b, buf1.16b + zip2 tmp1.16b, buf0.16b, buf1.16b + zip1 tmp2.16b, buf1.16b, buf2.16b + zip2 tmp3.16b, buf1.16b, buf2.16b + + bic tmp0.8h, #0xf0, lsl 8 + bic tmp1.8h, #0xf0, lsl 8 + ushr tmp2.8h, tmp2.8h, #4 + ushr tmp3.8h, tmp3.8h, #4 + + zip1 val0.8h, tmp0.8h, tmp2.8h + zip2 val1.8h, tmp0.8h, tmp2.8h + zip1 val2.8h, tmp1.8h, tmp3.8h + zip2 val3.8h, tmp1.8h, tmp3.8h + + // At this point, val0-val3 are the unsigned integers to do rejection + // sampling on. For each of them, do the following: + // - Check which coefficients are within range, and represent the set + // of lane-indices of those coefficients as an 8-bit bitmap. + // - Move the respective lanes to the front of the vector. This is the + // most complex part, and is done by interpreting the 8-bit bitmap as + // an index into a lookup table giving the lane-table to be used for + // the `tbl` instruction. + // - Write the vector to the output buffer, but merely increase the output + // buffer pointer by the number of valid coefficients. 
+ + // Set valid lanes (coefficient < MLKEM_Q, unsigned compare) to -1 (0b1...1) + cmhi sign0.8h, mlkem_q.8h, val0.8h + cmhi sign1.8h, mlkem_q.8h, val1.8h + cmhi sign2.8h, mlkem_q.8h, val2.8h + cmhi sign3.8h, mlkem_q.8h, val3.8h + + // If lane i is valid and has value -1, retain only i-th bit + and sign0.16b, sign0.16b, bits.16b + and sign1.16b, sign1.16b, bits.16b + and sign2.16b, sign2.16b, bits.16b + and sign3.16b, sign3.16b, bits.16b + + // Get 8-bit bitmap of valid lane indices by adding lanes + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + uaddlv t2, sign2.8h + uaddlv t3, sign3.8h + + fmov rec_idx_0, t0 + fmov rec_idx_1, t1 + fmov rec_idx_2, t2 + fmov rec_idx_3, t3 + + ldr table0q, [table_idx, rec_idx_0, uxtw #4] + ldr table1q, [table_idx, rec_idx_1, uxtw #4] + ldr table2q, [table_idx, rec_idx_2, uxtw #4] + ldr table3q, [table_idx, rec_idx_3, uxtw #4] + + // Compute number of valid coefficients. Recall that at this + // point, lane i has value 2^i (hence popcount 1) if its coefficient + // is valid, and 0 otherwise. + cnt sign0.16b, sign0.16b + cnt sign1.16b, sign1.16b + cnt sign2.16b, sign2.16b + cnt sign3.16b, sign3.16b + + // Extract number of valid coefficients + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + uaddlv t2, sign2.8h + uaddlv t3, sign3.8h + + fmov ctr0, t0 + fmov ctr1, t1 + fmov ctr2, t2 + fmov ctr3, t3 + + // Move valid coefficients to the front + tbl val0.16b, {val0.16b}, table0.16b + tbl val1.16b, {val1.16b}, table1.16b + tbl val2.16b, {val2.16b}, table2.16b + tbl val3.16b, {val3.16b}, table3.16b + + str val0q, [output_tmp] + add output_tmp, output_tmp, ctr0, uxtw #1 + + str val1q, [output_tmp] + add output_tmp, output_tmp, ctr1, uxtw #1 + + str val2q, [output_tmp] + add output_tmp, output_tmp, ctr2, uxtw #1 + + str val3q, [output_tmp] + add output_tmp, output_tmp, ctr3, uxtw #1 + + add ctr01, ctr0, ctr1 + add ctr23, ctr2, ctr3 + add count, count, ctr01 + add count, count, ctr23 + + cmp buflen, #48 + b.hs loop48 +loop48_end: + + // Finish once we've generated sufficiently many
coefficients + cmp count, len + b.hs memory_copy + + cmp buflen, #24 + b.lo memory_copy + + sub buflen, buflen, #24 + ld3 {buf0.8b, buf1.8b, buf2.8b}, [buf], #24 + + zip1 tmp0.16b, buf0.16b, buf1.16b + zip1 tmp1.16b, buf1.16b, buf2.16b + + bic tmp0.8h, #0xf0, lsl 8 + ushr tmp1.8h, tmp1.8h, #4 + + zip1 val0.8h, tmp0.8h, tmp1.8h + zip2 val1.8h, tmp0.8h, tmp1.8h + + cmhi sign0.8h, mlkem_q.8h, val0.8h + cmhi sign1.8h, mlkem_q.8h, val1.8h + + and sign0.16b, sign0.16b, bits.16b + and sign1.16b, sign1.16b, bits.16b + + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + + fmov rec_idx_0, t0 + fmov rec_idx_1, t1 + + ldr table0q, [table_idx, rec_idx_0, uxtw #4] + ldr table1q, [table_idx, rec_idx_1, uxtw #4] + + cnt sign0.16b, sign0.16b + cnt sign1.16b, sign1.16b + + uaddlv t0, sign0.8h + uaddlv t1, sign1.8h + + fmov ctr0, t0 + fmov ctr1, t1 + + tbl val0.16b, {val0.16b}, table0.16b + tbl val1.16b, {val1.16b}, table1.16b + + str val0q, [output_tmp] + add output_tmp, output_tmp, ctr0, uxtw #1 + + str val1q, [output_tmp] + add output_tmp, output_tmp, ctr1, uxtw #1 + + add count, count, ctr0 + add count, count, ctr1 + +memory_copy: + // count = min(count,len) + cmp count, len + csel count, count, len, lo + + // Always copy MLKEM_N coefficients from the stack to the destination, + // even if not all of them may be valid. This simplifies the loop and + // allows us to stick to vectorized code. 
+ mov final_copy_count, #0 + mov output_tmp, output_tmp_base +final_copy: + ldr val0q, [output_tmp], #64 + ldr val1q, [output_tmp, #-48] + ldr val2q, [output_tmp, #-32] + ldr val3q, [output_tmp, #-16] + str val0q, [output], #64 + str val1q, [output, #-48] + str val2q, [output, #-32] + str val3q, [output, #-16] + add final_copy_count, final_copy_count, #32 + cmp final_copy_count, #MLKEM_N + b.lt final_copy + + mov w0, count + b return + +return: + pop_stack + ret + +#endif /* defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_table.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_table.c new file mode 100644 index 0000000000..507660349d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/aarch64/src/rej_uniform_table.c @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +#include +#include "arith_native_aarch64.h" + +/* + * Lookup table used by rejection sampling of the public matrix. + * See autogen for details. 
+ */ +ALIGN const uint8_t rej_uniform_table[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 0 */, + 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 1 */, + 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 2 */, + 0, 1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 3 */, + 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 4 */, + 0, 1, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 5 */, + 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 6 */, + 0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 7 */, + 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 8 */, + 0, 1, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 9 */, + 2, 3, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 10 */, + 0, 1, 2, 3, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 11 */, + 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 12 */, + 0, 1, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 13 */, + 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 14 */, + 0, 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1 /* 15 */, + 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 16 */, + 0, 1, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 17 */, + 2, 3, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 18 */, + 0, 1, 2, 3, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 19 */, + 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 20 */, + 0, 1, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 21 */, + 2, 3, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 22 */, + 0, 1, 2, 3, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 23 */, + 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 24 */, + 0, 1, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 25 */, + 2, 3, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 26 */, + 0, 1, 2, 3, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, 
-1 /* 27 */, + 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 28 */, + 0, 1, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 29 */, + 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 30 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1 /* 31 */, + 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 32 */, + 0, 1, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 33 */, + 2, 3, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 34 */, + 0, 1, 2, 3, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 35 */, + 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 36 */, + 0, 1, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 37 */, + 2, 3, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 38 */, + 0, 1, 2, 3, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 39 */, + 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 40 */, + 0, 1, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 41 */, + 2, 3, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 42 */, + 0, 1, 2, 3, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 43 */, + 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 44 */, + 0, 1, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 45 */, + 2, 3, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 46 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1 /* 47 */, + 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 48 */, + 0, 1, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 49 */, + 2, 3, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 50 */, + 0, 1, 2, 3, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 51 */, + 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 52 */, + 0, 1, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 53 */, + 2, 3, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 54 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 55 */, + 6, 7, 8, 9, 10, 11, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 56 */, + 0, 1, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 57 */, + 2, 3, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 58 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 59 */, + 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 60 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 61 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 62 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1 /* 63 */, + 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 64 */, + 0, 1, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 65 */, + 2, 3, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 66 */, + 0, 1, 2, 3, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 67 */, + 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 68 */, + 0, 1, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 69 */, + 2, 3, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 70 */, + 0, 1, 2, 3, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 71 */, + 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 72 */, + 0, 1, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 73 */, + 2, 3, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 74 */, + 0, 1, 2, 3, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 75 */, + 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 76 */, + 0, 1, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 77 */, + 2, 3, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 78 */, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1 /* 79 */, + 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 80 */, + 0, 1, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 81 */, + 2, 3, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 82 */, + 0, 1, 2, 3, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 83 */, + 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1 /* 84 */, + 0, 1, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 85 */, + 2, 3, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 86 */, + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 87 */, + 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 88 */, + 0, 1, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 89 */, + 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 90 */, + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 91 */, + 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 92 */, + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 93 */, + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 94 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1 /* 95 */, + 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 96 */, + 0, 1, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 97 */, + 2, 3, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 98 */, + 0, 1, 2, 3, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 99 */, + 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 100 */, + 0, 1, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 101 */, + 2, 3, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 102 */, + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 103 */, + 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 104 */, + 0, 1, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 105 */, + 2, 3, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 106 */, + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 107 */, + 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 108 */, + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 109 */, + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 110 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1 /* 111 */, + 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 112 */, + 0, 1, 8, 9, 
10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 113 */, + 2, 3, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 114 */, + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 115 */, + 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 116 */, + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 117 */, + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 118 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 119 */, + 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 120 */, + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 121 */, + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 122 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 123 */, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 124 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 125 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 126 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1 /* 127 */, + 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 128 */, + 0, 1, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 129 */, + 2, 3, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 130 */, + 0, 1, 2, 3, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 131 */, + 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 132 */, + 0, 1, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 133 */, + 2, 3, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 134 */, + 0, 1, 2, 3, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 135 */, + 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 136 */, + 0, 1, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 137 */, + 2, 3, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 138 */, + 0, 1, 2, 3, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 139 */, + 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 140 */, + 0, 1, 4, 5, 6, 7, 14, 
15, -1, -1, -1, -1, -1, -1, -1, -1 /* 141 */, + 2, 3, 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 142 */, + 0, 1, 2, 3, 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1 /* 143 */, + 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 144 */, + 0, 1, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 145 */, + 2, 3, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 146 */, + 0, 1, 2, 3, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 147 */, + 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 148 */, + 0, 1, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 149 */, + 2, 3, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 150 */, + 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 151 */, + 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 152 */, + 0, 1, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 153 */, + 2, 3, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 154 */, + 0, 1, 2, 3, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 155 */, + 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 156 */, + 0, 1, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 157 */, + 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 158 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1 /* 159 */, + 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 160 */, + 0, 1, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 161 */, + 2, 3, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 162 */, + 0, 1, 2, 3, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 163 */, + 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 164 */, + 0, 1, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 165 */, + 2, 3, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 166 */, + 0, 1, 2, 3, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 167 */, + 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 168 */, + 0, 1, 6, 7, 10, 11, 14, 15, -1, -1, 
-1, -1, -1, -1, -1, -1 /* 169 */, + 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 170 */, + 0, 1, 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 171 */, + 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 172 */, + 0, 1, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 173 */, + 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 174 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1 /* 175 */, + 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 176 */, + 0, 1, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 177 */, + 2, 3, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 178 */, + 0, 1, 2, 3, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 179 */, + 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 180 */, + 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 181 */, + 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 182 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 183 */, + 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 184 */, + 0, 1, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 185 */, + 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 186 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 187 */, + 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 188 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 189 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 190 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1 /* 191 */, + 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 192 */, + 0, 1, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 193 */, + 2, 3, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 194 */, + 0, 1, 2, 3, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 195 */, + 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 196 */, + 0, 1, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, 
-1, -1 /* 197 */, + 2, 3, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 198 */, + 0, 1, 2, 3, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 199 */, + 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 200 */, + 0, 1, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 201 */, + 2, 3, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 202 */, + 0, 1, 2, 3, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 203 */, + 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 204 */, + 0, 1, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 205 */, + 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 206 */, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1 /* 207 */, + 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 208 */, + 0, 1, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 209 */, + 2, 3, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 210 */, + 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 211 */, + 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 212 */, + 0, 1, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 213 */, + 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 214 */, + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 215 */, + 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 216 */, + 0, 1, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 217 */, + 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 218 */, + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 219 */, + 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 220 */, + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 221 */, + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 222 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1 /* 223 */, + 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 224 */, + 0, 1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 225 */, + 
2, 3, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 226 */, + 0, 1, 2, 3, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 227 */, + 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 228 */, + 0, 1, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 229 */, + 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 230 */, + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 231 */, + 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 232 */, + 0, 1, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 233 */, + 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 234 */, + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 235 */, + 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 236 */, + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 237 */, + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 238 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1 /* 239 */, + 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 240 */, + 0, 1, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 241 */, + 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 242 */, + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 243 */, + 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 244 */, + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 245 */, + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 246 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 247 */, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 248 */, + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 249 */, + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 250 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 251 */, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 252 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 253 */, + 2, 3, 4, 5, 
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 254 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 /* 255 */, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_aarch64_rej_uniform_table \ + MLKEM_NAMESPACE(empty_cu_aarch64_rej_uniform_table) +int empty_cu_aarch64_rej_uniform_table; +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include <stdint.h> +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. 
This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. They are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. 
+ */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions." +#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is assumed to be in normal order. 
+ * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. 
+ * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial vector operand. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial vector operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. 
Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: De-serialization of a polynomial from a byte array. + * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Return -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len). 
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include <stdint.h> + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. + * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * Every 4 input bytes yield 8 coefficients, each in -2..2. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= 
MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); /* next 4 input bytes */ + uint32_t d = t & 0x55555555; /* bits at even positions of t */ + d += (t >> 1) & 0x55555555; /* add odd-position bits: every 2-bit field of d is now a pair sum in 0..2 */ + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; /* low pair sum of nibble j */ + const int16_t b = (d >> (4 * j + 2)) & 0x3; /* high pair sum of nibble j */ + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. + * This function is only needed for ML-KEM-512 + * Every 3 input bytes yield 4 coefficients, each in -3..3. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); /* next 3 input bytes */ + uint32_t d = t & 0x00249249; /* every third bit of t, starting at bit 0 */ + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; /* every 3-bit field of d is now a sum of three bits, 0..3 */ + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; /* low triple sum of 6-bit group j */ + const int16_t b = (d >> (6 * j + 3)) & 0x7; /* high triple sum of 6-bit group j */ + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ /* MLKEM_ETA1 selects the sampler at compile time */ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include <stdint.h> +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-valued predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { unsigned k; (0 <= k && k < MLKEM_N) ==> ( + * (0 <= a->coeffs[k]) && (a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1. 
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols. 
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * then configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers. 
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202__ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM__ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLKEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if 
(err == 1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. 
+ * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. 
+ * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. + **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. 
*/ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + * key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * polyvec *b: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input 
serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. 
+ * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block a time until we're done. + */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. */ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block at a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. 
*/ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. + */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. 
*/ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). 
+ **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output buffers are needed. + * The laster parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + 
+#include +#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. 
+ * Described in Section 7.2 of FIPS203. + * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no public information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * length MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - uint16_t *a: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define 
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < 
(1u << 11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
#if MLKEM_K == 2 || MLKEM_K == 4
/*
 * poly_getnoise_eta2: sample a noise polynomial with parameter MLKEM_ETA2.
 *
 * The PRF input is the seed followed by the one-byte nonce; the PRF output
 * is turned into a polynomial by poly_cbd_eta2().
 *
 * Arguments: - r:     output polynomial; checked to be bounded by
 *                     MLKEM_ETA2 + 1 (exclusive) in absolute value.
 *            - seed:  input seed of MLKEM_SYMBYTES bytes.
 *            - nonce: one-byte nonce appended to the seed.
 */
MLKEM_NATIVE_INTERNAL_API
void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES],
                        uint8_t nonce)
{
  ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4];
  ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1];

  /* extkey = seed || nonce */
  memcpy(extkey, seed, MLKEM_SYMBYTES);
  extkey[MLKEM_SYMBYTES] = nonce;
  prf_eta2(buf, extkey);

  poly_cbd_eta2(r, buf);

  /*
   * The sample comes from the eta2 distribution, so the check must use
   * MLKEM_ETA2 (as poly_getnoise_eta1122_4x does for its eta2 outputs),
   * not MLKEM_ETA1: for MLKEM_K == 2 the latter is larger and would make
   * the check weaker than intended.
   */
  POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output");
}
#endif /* MLKEM_K == 2 || MLKEM_K == 4 */
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^32 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; /* round(d0/2^33) */ + return (d0 & 0x7FF); /* reduce modulo 2^11 */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328). 
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + * polynomial; approximate inverse of poly_compress_dv + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + * bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. 
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < q in absolute value. + * + * The result is coefficient-wise bound by 3/2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute(). 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtracts b from r in place; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be + * subtracted from. + * - const poly *b: Pointer to input polynomial to subtract from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + 
+/* Applies poly_ntt() to each of the MLKEM_K polynomials of r, in place. */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +/* Applies poly_invntt_tomont() to each of the MLKEM_K polynomials of r, in place. */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/* C fallback: accumulates the MLKEM_K cached base multiplications + * a->vec[i] * b->vec[i] into r via poly_add(). */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 2 * q, matching the cassert below */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constrain native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +/* Native variant: checks the input bounds, then forwards to the backend. */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * 
Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients + * normalized in [0..4095]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. 
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffers + * is guaranteed to have been consumed. 
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible. 
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffers + * is guaranteed to have been consumed. 
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * 
Referring to (eq 4.3), `OUT` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* unrecognized __BYTE_ORDER__ value */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ value */ +#endif /* defined(__BYTE_ORDER__) */ +#endif /* !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)). 
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining zero masks of various sizes and signedness */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return (uint32_t)ct_opt_blocker_u64; } + +static INLINE int32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return (int32_t)ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned conversion warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour. 
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time. 
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value. 
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_aarch64/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. + */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/LICENSE similarity index 100% rename from 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/LICENSE rename to src/kem/ml_kem/mlkem-native_ml-kem-512_ref/LICENSE diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. 
In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. They are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. + */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions." 
+#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is assumed to be in normal order. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial. 
+ * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial operand. 
+ * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: De-serialization of a polynomial + * from a byte array; counterpart of + * poly_tobytes_native. 
+ * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Return -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len). + **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. 
+ * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. 
+ * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. 
+ * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) 
+/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-valued predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { unsigned k; (0 <= k && k < MLKEM_N) ==> ( + * (0 <= a->coeffs[k]) && (a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1. 
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit a leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols. 
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * than configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers. 
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202__ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM__ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLkEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if (err == 1) 
+ exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* of poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. + * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64_AVX2 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. 
with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. + * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. 
+ **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. */ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + *key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * poly *pk: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void 
pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block a time until we're done. 
+ */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. 
*/ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. */ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. 
+ */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. */ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. 
Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). + **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output 
buffers are needed. + * The last parameter is a dummy that's overwritten later. + */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + +#include 
+#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying ML-KEM. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - const uint8_t *coins: pointer to input random coins used as + * seed (of length MLKEM_SYMBYTES bytes) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying ML-KEM.
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. + * Described in Section 7.2 of FIPS203. 
+ * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768k+32 : 768k+64] = H(pk) = H(sk[384k : 768k+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no secret information is leaked through the runtime or the return value + * of this function.
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * length MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes.
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication.
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) +
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + * Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for this layer start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value.
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 
320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < (1u << 11)))) + { + t[k] = 
scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); /* extkey = seed || nonce */ + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); /* r comes from poly_cbd_eta2, so the bound is MLKEM_ETA2 + 1 (as in poly_getnoise_eta1122_4x), not MLKEM_ETA1 + 1 */ +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1,
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required.
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328).
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DU) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1].
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DU bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1].
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress_dv + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization.
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds (as stated by the contract below): + * - a is assumed to be coefficient-wise in [0, UINT12_LIMIT). + * + * The result is coefficient-wise bound by 2 * MLKEM_Q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute().
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials.
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be + * subtracted from. + * - const poly *b: Pointer to input polynomial to subtract from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing.
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + +MLKEM_NATIVE_INTERNAL_API 
+void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 3/2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constraint native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * Arguments: - 
uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized in [0..4095]. + * - const uint8_t *a: pointer to input byte array
+ * (of length MLKEM_POLYVECBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value.
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible.
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * Referring to
(eq 4.3), `OUT` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORER__ */ +#endif /* !defined(__BYTE_ORER__) */ +#endif /* defined(SYS_LITTLE_ENDIAN) || defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)). 
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + hinders value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered.
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE int32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return (int32_t)ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise.
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour.
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time.
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * Assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value.
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_ref/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. + */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/LICENSE similarity index 100% rename from 
src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/LICENSE rename to src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/LICENSE diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. 
In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. They are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. + */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions."
+#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * NTT, invNTT, basemul, mulcache and to/from bytes. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is in normal order. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial.
+ * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: In-place conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial vector operand.
+ * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial vector operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: Deserialization of a polynomial. + * Converts a byte array of MLKEM_POLYBYTES + * bytes into a polynomial.
+ * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Returns -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len).
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. + * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= 
MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. + * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2.
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-valued predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { unsigned k; (0 <= k && k < MLKEM_N) ==> ( + * (0 <= a->coeffs[k]) && (a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1.
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit a leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols.
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * than configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers.
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202_<BACKEND>_ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM<LEVEL>_<BACKEND>_ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLKEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if 
(err == 1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. 
+ * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. 
+ * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. + **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. 
*/ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + *key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * poly *pk: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input 
serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. 
+ * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block a time until we're done. + */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. */ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. 
*/ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. + */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. 
*/ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). 
+ **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + 
+#include +#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. 
+ * Described in Section 7.2 of FIPS203. + * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no secret information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * length MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value.
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < (1u 
<< 11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ /* Fill r via poly_cbd_eta2 from prf_eta2 applied to seed || nonce */ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); /* extkey = seed || nonce */ + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); /* eta2 sample: check against ETA2, as poly_getnoise_eta1122_4x does */ +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1,
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required.
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328).
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time: a branch-free select between c + MLKEM_Q and c */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* c is now in [0, MLKEM_Q), and therefore the cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1].
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + * polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DU bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1].
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + * polynomial; approximate inverse of poly_compress_dv + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + * bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization.
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) /* NOTE(review): the input must already be unsigned canonical; the header's remark that signed coefficients are converted looks stale -- confirm */ + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial; on return its coefficients are in [0,1,..,MLKEM_Q-1] + * - const uint8_t *msg: pointer to input message (of length MLKEM_INDCPA_MSGBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical, i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case.
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive, the two pairs in different objects */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive, r3 in a different object */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial (coefficients bound by MLKEM_ETA2 in absolute value on return) + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce +
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive, the two pairs in different objects; r0, r1 get the MLKEM_ETA1 bound below and r2, r3 the MLKEM_ETA2 bound */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < q in absolute value. + * + * The result is coefficient-wise bound by 3/2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute().
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) /* NOTE(review): the header comment instead states |a| < q -- confirm which input bound is intended */ + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) /* NOTE(review): the header comment states a 3/2 q output bound, the contract only 2 q -- confirm */ +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials.
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials in place (r := r - b); no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be subtracted + * from. + * - const poly *b: Pointer to input polynomial that is subtracted from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing.
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include /* NOTE(review): the header name is missing after this #include (presumably a system header such as stdint.h lost in extraction) -- confirm and restore */ +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} +
+MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 2 * q, matching the MLKEM_K * 2 * MLKEM_Q bound asserted after the loop */ + } + + /* + * The bounds asserted below are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constrain native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check.
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; /* local mulcache for b, recomputed on every call */ + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include /* NOTE(review): the header name is missing after this #include (presumably a system header such as stdint.h lost in extraction) -- confirm and restore */ +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + *
Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - const polyvec *a: pointer to output vector of polynomials + * (of length MLKEM_POLYVECBYTES). Output will have coefficients + * normalized in [0..4095]. 
+ * - uint8_t *r: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) /* NOTE(review): the header comment swaps the argument names -- per this signature r is the output polyvec and a the input byte array of MLKEM_POLYVECBYTES bytes */ +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomials are in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value.
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible. 
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * 
Referring to (eq 4.3), `OUT` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORDER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ */ +#endif /* !defined(__BYTE_ORDER__) */ +#endif /* defined(SYS_LITTLE_ENDIAN) || defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)).
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + hinders value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered.
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour.
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time.
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value.
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/README.md b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/README.md new file mode 100644 index 0000000000..2073425c3b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/README.md @@ -0,0 +1,4 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +This directory contains the native x86_64 arithmetic backend for ML-KEM provided by the official [AVX2 +implementation](https://github.com/pq-crystals/kyber/tree/main/avx2) of the Kyber team. diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/default.h new file mode 100644 index 0000000000..592e8996dc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/default.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME X86_64_DEFAULT + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. 
*/ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "x86_64/src/default_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/align.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/align.h new file mode 100644 index 0000000000..42a02fe57c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/align.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/align.h + */ + +#ifndef ALIGN_H +#define ALIGN_H + +#include +#include + +#define ALIGNED_UINT8(N) \ + union \ + { \ + uint8_t coeffs[N]; \ + __m256i vec[(N + 31) / 32]; \ + } + +#define ALIGNED_INT16(N) \ + union \ + { \ + int16_t coeffs[N]; \ + __m256i vec[(N + 15) / 16]; \ + } + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/arith_native_x86_64.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/arith_native_x86_64.h new file mode 100644 index 0000000000..ce13e7911f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/arith_native_x86_64.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_X86_64_NATIVE_H +#define MLKEM_X86_64_NATIVE_H + +#include "common.h" + +#include +#include +#include "polyvec.h" +#include "consts.h" + +#define REJ_UNIFORM_AVX_NBLOCKS 3 /* See MLKEM_GEN_MATRIX_NBLOCKS */ +#define REJ_UNIFORM_AVX_BUFLEN \ + (3 * 168) /* REJ_UNIFORM_AVX_NBLOCKS * SHAKE128_RATE */ + +#define rej_uniform_avx2 MLKEM_NAMESPACE(rej_uniform_avx2) +unsigned int rej_uniform_avx2(int16_t *r, const uint8_t *buf); + +#define rej_uniform_table MLKEM_NAMESPACE(rej_uniform_table) +extern const uint8_t rej_uniform_table[256][8]; + +#define ntt_avx2 MLKEM_NAMESPACE(ntt_avx2) +void ntt_avx2(__m256i *r, const
__m256i *qdata); + +#define invntt_avx2 MLKEM_NAMESPACE(invntt_avx2) +void invntt_avx2(__m256i *r, const __m256i *qdata); + +#define nttpack_avx2 MLKEM_NAMESPACE(nttpack_avx2) +void nttpack_avx2(__m256i *r, const __m256i *qdata); + +#define nttunpack_avx2 MLKEM_NAMESPACE(nttunpack_avx2) +void nttunpack_avx2(__m256i *r, const __m256i *qdata); + +#define reduce_avx2 MLKEM_NAMESPACE(reduce_avx2) +void reduce_avx2(__m256i *r, const __m256i *qdata); + +#define basemul_avx2 MLKEM_NAMESPACE(basemul_avx2) +void basemul_avx2(__m256i *r, const __m256i *a, const __m256i *b, + const __m256i *qdata); + +#define polyvec_basemul_acc_montgomery_cached_avx2 \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_avx2) +void polyvec_basemul_acc_montgomery_cached_avx2( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); + +#define ntttobytes_avx2 MLKEM_NAMESPACE(ntttobytes_avx2) +void ntttobytes_avx2(uint8_t *r, const __m256i *a, const __m256i *qdata); + +#define nttfrombytes_avx2 MLKEM_NAMESPACE(nttfrombytes_avx2) +void nttfrombytes_avx2(__m256i *r, const uint8_t *a, const __m256i *qdata); + +#define tomont_avx2 MLKEM_NAMESPACE(tomont_avx2) +void tomont_avx2(__m256i *r, const __m256i *qdata); + +#endif /* MLKEM_X86_64_NATIVE_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/basemul.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.S similarity index 61% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/basemul.S rename to src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.S index 36990639b2..b97840e702 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/basemul.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.S @@ -1,12 +1,25 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// 
https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + #include "consts.h" +/* Polynomials to be multiplied are denoted a+bX (rsi arg) and c+dX (rdx arg) */ .macro schoolbook off -vmovdqa _16XQINV*2(%rcx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQINV*2(%rcx),%ymm0 vmovdqa (64*\off+ 0)*2(%rsi),%ymm1 # a0 vmovdqa (64*\off+16)*2(%rsi),%ymm2 # b0 vmovdqa (64*\off+32)*2(%rsi),%ymm3 # a1 vmovdqa (64*\off+48)*2(%rsi),%ymm4 # b1 +/* Prepare Montgomery twists */ vpmullw %ymm0,%ymm1,%ymm9 # a0.lo vpmullw %ymm0,%ymm2,%ymm10 # b0.lo vpmullw %ymm0,%ymm3,%ymm11 # a1.lo @@ -15,6 +28,7 @@ vpmullw %ymm0,%ymm4,%ymm12 # b1.lo vmovdqa (64*\off+ 0)*2(%rdx),%ymm5 # c0 vmovdqa (64*\off+16)*2(%rdx),%ymm6 # d0 +/* Compute high-parts of monomials in (a0+b0*X)*(c0+d0*X) */ vpmulhw %ymm5,%ymm1,%ymm13 # a0c0.hi vpmulhw %ymm6,%ymm1,%ymm1 # a0d0.hi vpmulhw %ymm5,%ymm2,%ymm14 # b0c0.hi @@ -23,6 +37,8 @@ vpmulhw %ymm6,%ymm2,%ymm2 # b0d0.hi vmovdqa (64*\off+32)*2(%rdx),%ymm7 # c1 vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1 +/* Compute high-parts of monomials in (a1+b1*X)*(c1+d1*X) */ +/* Don't yet accumulate nor reduce X^2 */ vpmulhw %ymm7,%ymm3,%ymm15 # a1c1.hi vpmulhw %ymm8,%ymm3,%ymm3 # a1d1.hi vpmulhw %ymm7,%ymm4,%ymm0 # b1c1.hi @@ -30,17 +46,22 @@ vpmulhw %ymm8,%ymm4,%ymm4 # b1d1.hi vmovdqa %ymm13,(%rsp) +/* Compute low-parts of monomials in (a0+b0*X)*(c0+d0*X), */ +/* using Montgomery twists calculated before */ vpmullw %ymm5,%ymm9,%ymm13 # a0c0.lo vpmullw %ymm6,%ymm9,%ymm9 # a0d0.lo vpmullw %ymm5,%ymm10,%ymm5 # b0c0.lo vpmullw %ymm6,%ymm10,%ymm10 # b0d0.lo +/* Compute low-parts of monomials in (a1+b1*X)*(c1+d1*X), */ +/* using Montgomery twists calculated before */ vpmullw %ymm7,%ymm11,%ymm6 # a1c1.lo vpmullw %ymm8,%ymm11,%ymm11 # a1d1.lo vpmullw %ymm7,%ymm12,%ymm7 # b1c1.lo vpmullw %ymm8,%ymm12,%ymm12 # b1d1.lo -vmovdqa _16XQ*2(%rcx),%ymm8 +/* Compute 2nd high multiplication in Montgomery multiplication */ +vmovdqa 
AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rcx),%ymm8 vpmulhw %ymm8,%ymm13,%ymm13 vpmulhw %ymm8,%ymm9,%ymm9 vpmulhw %ymm8,%ymm5,%ymm5 @@ -50,6 +71,7 @@ vpmulhw %ymm8,%ymm11,%ymm11 vpmulhw %ymm8,%ymm7,%ymm7 vpmulhw %ymm8,%ymm12,%ymm12 +/* Finish Montgomery multiplications */ vpsubw (%rsp),%ymm13,%ymm13 # -a0c0 vpsubw %ymm9,%ymm1,%ymm9 # a0d0 vpsubw %ymm5,%ymm14,%ymm5 # b0c0 @@ -60,6 +82,10 @@ vpsubw %ymm11,%ymm3,%ymm11 # a1d1 vpsubw %ymm7,%ymm0,%ymm7 # b1c1 vpsubw %ymm12,%ymm4,%ymm12 # b1d1 +/* b0*d0 and b1*d1 need twisting by a twiddle, accounting + * for X^2=zeta in F_q[X]/(X^2-zeta). + * + * TODO: This could be precomputed in the mulcache */ vmovdqa (%r9),%ymm0 vmovdqa 32(%r9),%ymm1 vpmullw %ymm0,%ymm10,%ymm2 @@ -76,6 +102,9 @@ vpaddw %ymm7,%ymm11,%ymm11 vpsubw %ymm13,%ymm10,%ymm13 vpsubw %ymm12,%ymm6,%ymm6 +/* Bounds: Since we are multiplying with signed canonical twiddles, + * each Montgomery multiplication has absolute value < q, + * and hence the coefficients of the output have absolute value < 2q. 
*/ vmovdqa %ymm13,(64*\off+ 0)*2(%rdi) vmovdqa %ymm9,(64*\off+16)*2(%rdi) vmovdqa %ymm6,(64*\off+32)*2(%rdi) @@ -83,13 +112,13 @@ vmovdqa %ymm11,(64*\off+48)*2(%rdi) .endm .text -.global cdecl(basemul_avx) -cdecl(basemul_avx): +.global MLKEM_ASM_NAMESPACE(basemul_avx2) +MLKEM_ASM_NAMESPACE(basemul_avx2): mov %rsp,%r8 and $-32,%rsp sub $32,%rsp -lea (_ZETAS_EXP+176)*2(%rcx),%r9 +lea (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+176)*2(%rcx),%r9 schoolbook 0 add $32*2,%r9 @@ -103,3 +132,5 @@ schoolbook 3 mov %r8,%rsp ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.c new file mode 100644 index 0000000000..5f9ae99c80 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/basemul.c @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "poly.h" +#include "polyvec.h" + +#include "arith_native_x86_64.h" +#include "consts.h" + +static void poly_basemul_montgomery_avx2(poly *r, const poly *a, const poly *b) +{ + basemul_avx2((__m256i *)r->coeffs, (const __m256i *)a->coeffs, + (const __m256i *)b->coeffs, qdata.vec); +} + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ +static void poly_add_avx2(poly *r, const poly *a, const poly *b) +{ + unsigned i; + __m256i f0, f1; + + for (i = 0; i < MLKEM_N; i += 16) + { + f0 = _mm256_load_si256((const __m256i *)&a->coeffs[i]); + f1 = _mm256_load_si256((const __m256i *)&b->coeffs[i]); + f0 = _mm256_add_epi16(f0, f1); + _mm256_store_si256((__m256i *)&r->coeffs[i], f0); + } +} + +void polyvec_basemul_acc_montgomery_cached_avx2(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + /* TODO: 
Use mulcache for AVX2. So far, it is unused. */ + ((void)b_cache); + + /* Coefficient-wise bound of each basemul is 2q. + * Since we are accumulating at most 4 times, the + * overall bound is 8q < INT16_MAX. */ + poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_avx2(&t, &a->vec[i], &b->vec[i]); + poly_add_avx2(r, r, &t); + } +} + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy constant to keep compiler happy despite empty CU */ + +#define empty_cu_avx2_basemul MLKEM_NAMESPACE(empty_cu_avx2_basemul) +int empty_cu_avx2_basemul; + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.c new file mode 100644 index 0000000000..86a0835efd --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.c @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/consts.c + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "align.h" +#include "consts.h" + +#define Q MLKEM_Q +#define MONT -1044 /* 2^16 mod q */ +#define QINV -3327 /* q^-1 mod 2^16 */ +#define V 20159 /* floor(2^26/q + 0.5) */ +#define FHI 1441 /* mont^2/128 */ +#define FLO -10079 /* qinv*FHI */ +#define MONTSQHI 1353 /* mont^2 */ +#define MONTSQLO 20553 /* qinv*MONTSQHI */ +#define MASK 4095 +#define SHIFT 32 + +const qdata_t qdata = {{ +#define AVX2_BACKEND_DATA_OFFSET_16XQ 0 + Q, Q, Q, Q, Q, Q, + Q, Q, Q, Q, Q, Q, + Q, Q, Q, Q, + +#define AVX2_BACKEND_DATA_OFFSET_16XQINV 16 + QINV, QINV, QINV, QINV, QINV, QINV, + QINV, QINV, QINV, QINV, QINV, QINV, + QINV, QINV, QINV, QINV, + +#define AVX2_BACKEND_DATA_OFFSET_16XV 32 + V, V, V, V, V, V, 
+ V, V, V, V, V, V, + V, V, V, V, + +#define AVX2_BACKEND_DATA_OFFSET_16XFLO 48 + FLO, FLO, FLO, FLO, FLO, FLO, + FLO, FLO, FLO, FLO, FLO, FLO, + FLO, FLO, FLO, FLO, + +#define AVX2_BACKEND_DATA_OFFSET_16XFHI 64 + FHI, FHI, FHI, FHI, FHI, FHI, + FHI, FHI, FHI, FHI, FHI, FHI, + FHI, FHI, FHI, FHI, + +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO 80 + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI 96 + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + +#define AVX2_BACKEND_DATA_OFFSET_16XMASK 112 + MASK, MASK, MASK, MASK, MASK, MASK, + MASK, MASK, MASK, MASK, MASK, MASK, + MASK, MASK, MASK, MASK, + +#define AVX2_BACKEND_DATA_OFFSET_REVIDXB 128 + 3854, 3340, 2826, 2312, 1798, 1284, + 770, 256, 3854, 3340, 2826, 2312, + 1798, 1284, 770, 256, + +#define AVX2_BACKEND_DATA_OFFSET_REVIDXD 144 + 7, 0, 6, 0, 5, 0, + 4, 0, 3, 0, 2, 0, + 1, 0, 0, 0, + +#define AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP 160 +#include "x86_64_zetas.i" + +#define AVX2_BACKEND_DATA_OFFSET_16XSHIFT 624 + SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, + SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, + SHIFT, SHIFT, SHIFT, SHIFT}}; + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_consts MLKEM_NAMESPACE(empty_cu_consts) +int empty_cu_consts; +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.h new file mode 100644 index 0000000000..00c415952e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/consts.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 The mlkem-native project 
authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/consts.h + */ + +#ifndef CONSTS_H +#define CONSTS_H + +#include "common.h" + +#define AVX2_BACKEND_DATA_OFFSET_16XQ 0 +#define AVX2_BACKEND_DATA_OFFSET_16XQINV 16 +#define AVX2_BACKEND_DATA_OFFSET_16XV 32 +#define AVX2_BACKEND_DATA_OFFSET_16XFLO 48 +#define AVX2_BACKEND_DATA_OFFSET_16XFHI 64 +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO 80 +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI 96 +#define AVX2_BACKEND_DATA_OFFSET_16XMASK 112 +#define AVX2_BACKEND_DATA_OFFSET_REVIDXB 128 +#define AVX2_BACKEND_DATA_OFFSET_REVIDXD 144 +#define AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP 160 +#define AVX2_BACKEND_DATA_OFFSET_16XSHIFT 624 + +/* The C ABI on MacOS exports all symbols with a leading + * underscore. This means that any symbols we refer to from + * C files (functions) can't be found, and all symbols we + * refer to from ASM also can't be found. + * + * This define helps us get around this + */ + +#ifndef __ASSEMBLER__ +#include "align.h" +typedef ALIGNED_INT16(640) qdata_t; +#define qdata MLKEM_NAMESPACE(qdata) +extern const qdata_t qdata; +#endif + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/default_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/default_impl.h new file mode 100644 index 0000000000..66de8c85f3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/default_impl.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include + +#include "poly.h" +#include "polyvec.h" +#include "arith_native_x86_64.h" + +#define MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + +#define MLKEM_USE_NATIVE_REJ_UNIFORM +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_POLY_FROMBYTES + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +#define NTT_BOUND_NATIVE (8 * MLKEM_Q) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +{ + nttunpack_avx2((__m256i *)(data->coeffs), qdata.vec); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + /* AVX2 implementation assumes specific buffer lengths */ + if (len != MLKEM_N || buflen != REJ_UNIFORM_AVX_BUFLEN) + { + return -1; + } + + return (int)rej_uniform_avx2(r, buf); +} + +static INLINE void ntt_native(poly *data) +{ + ntt_avx2((__m256i *)data, qdata.vec); +} + +static INLINE void intt_native(poly *data) +{ + invntt_avx2((__m256i *)data, qdata.vec); +} + +static INLINE void poly_reduce_native(poly *data) +{ + reduce_avx2((__m256i *)data->coeffs, qdata.vec); +} + +static INLINE void poly_tomont_native(poly *data) +{ + tomont_avx2((__m256i *)data->coeffs, qdata.vec); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + /* AVX2 backend does not use mulcache */ + ((void)y); + ((void)x); +} + +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_avx2(r, a, b, b_cache); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + ntttobytes_avx2(r, (const __m256i *)a->coeffs, qdata.vec); 
+} + +static INLINE void poly_frombytes_native(poly *r, + const uint8_t a[MLKEM_POLYBYTES]) +{ + nttfrombytes_avx2((__m256i *)r->coeffs, a, qdata.vec); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/fq.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.S similarity index 50% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/fq.S rename to src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.S index 3bb1ebd3d8..134bd4f710 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/fq.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.S @@ -1,8 +1,25 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation based on Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +// Changes: +// - Add call to csub in reduce128_avx to produce outputs +// in [0,1,...,q-1] rather than [0,1,...,q], matching the +// semantics of poly_reduce(). 
+ +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) #include "consts.h" -.include "fq.inc" + +#include "fq.inc" .text -reduce128_avx: +reduce128_avx2: #load vmovdqa (%rdi),%ymm2 vmovdqa 32(%rdi),%ymm3 @@ -22,6 +39,15 @@ red16 7 red16 8 red16 9 +csubq 2 +csubq 3 +csubq 4 +csubq 5 +csubq 6 +csubq 7 +csubq 8 +csubq 9 + #store vmovdqa %ymm2,(%rdi) vmovdqa %ymm3,32(%rdi) @@ -34,17 +60,18 @@ vmovdqa %ymm9,224(%rdi) ret -.global cdecl(reduce_avx) -cdecl(reduce_avx): +.global MLKEM_ASM_NAMESPACE(reduce_avx2) +MLKEM_ASM_NAMESPACE(reduce_avx2): #consts -vmovdqa _16XQ*2(%rsi),%ymm0 -vmovdqa _16XV*2(%rsi),%ymm1 -call reduce128_avx +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XV*2(%rsi),%ymm1 +call reduce128_avx2 add $256,%rdi -call reduce128_avx +call reduce128_avx2 ret -tomont128_avx: + +tomont128_avx2: #load vmovdqa (%rdi),%ymm3 vmovdqa 32(%rdi),%ymm4 @@ -76,13 +103,15 @@ vmovdqa %ymm10,224(%rdi) ret -.global cdecl(tomont_avx) -cdecl(tomont_avx): +.global MLKEM_ASM_NAMESPACE(tomont_avx2) +MLKEM_ASM_NAMESPACE(tomont_avx2): #consts -vmovdqa _16XQ*2(%rsi),%ymm0 -vmovdqa _16XMONTSQLO*2(%rsi),%ymm1 -vmovdqa _16XMONTSQHI*2(%rsi),%ymm2 -call tomont128_avx +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO*2(%rsi),%ymm1 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI*2(%rsi),%ymm2 +call tomont128_avx2 add $256,%rdi -call tomont128_avx +call tomont128_avx2 ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/fq.inc b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.inc similarity index 67% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/fq.inc rename to src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.inc index 4b7afc3118..76ec7a3b9e 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/fq.inc +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/fq.inc @@ -1,3 +1,13 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ + .macro red16 r,rs=0,x=12 vpmulhw %ymm1,%ymm\r,%ymm\x .if \rs @@ -22,6 +32,8 @@ vpand %ymm0,%ymm\x,%ymm\x vpaddw %ymm\x,%ymm\r,%ymm\r .endm +/* Montgomery multiplication between b and ah, + * with Montgomery twist of ah in al. */ .macro fqmulprecomp al,ah,b,x=12 vpmullw %ymm\al,%ymm\b,%ymm\x vpmulhw %ymm\ah,%ymm\b,%ymm\b diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/intt.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/intt.S new file mode 100644 index 0000000000..6b1d78ef26 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/intt.S @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* Implementation based on Kyber repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + * + * Changes to placement of modular reductions have + * been made to simplify reasoning of non-overflow */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "consts.h" +#include "shuffle.inc" +#include "fq.inc" + +/* Compute four GS butterflies between rh{0,1,2,3} and rl{0,1,2,3}. 
+ * Butterflies 0,1 use root zh0 and twisted root zl0, and butterflies + * 2,3 use root zh1 and twisted root zl1 + * Results are again in rl{0-3} and rh{0-3} */ +.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3 +vpsubw %ymm\rl0,%ymm\rh0,%ymm12 /* ymm12 = rh0 - rl0 */ +vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 /* rl0 = rh0 + rl0 */ +vpsubw %ymm\rl1,%ymm\rh1,%ymm13 /* ymm13 = rh1 - rl1 */ + +vpmullw %ymm\zl0,%ymm12,%ymm\rh0 /* rh0 = (rh0 - rl0) * root0_twisted */ +vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 /* rl1 = rh1 + rl1 */ +vpsubw %ymm\rl2,%ymm\rh2,%ymm14 /* ymm14 = rh2 - rl2 */ + +vpmullw %ymm\zl0,%ymm13,%ymm\rh1 /* rh1 = (rh1 - rl1) * root0_twisted */ +vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 /* rl2 = rh2 + rl2 */ +vpsubw %ymm\rl3,%ymm\rh3,%ymm15 /* ymm15 = rh3 - rl3 */ + +vpmullw %ymm\zl1,%ymm14,%ymm\rh2 /* rh2 = (rh2 - rl2) * root1_twisted */ +vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 /* rl3 = rh3 + rl3 */ +vpmullw %ymm\zl1,%ymm15,%ymm\rh3 /* rh3 = (rh3 - rl3) * root1_twisted */ + +vpmulhw %ymm\zh0,%ymm12,%ymm12 /* ymm12 = (rh0 - rl0) * root0 */ +vpmulhw %ymm\zh0,%ymm13,%ymm13 /* ymm13 = (rh1 - rl1) * root0 */ + +vpmulhw %ymm\zh1,%ymm14,%ymm14 /* ymm14 = (rh2 - rl2) * root1 */ +vpmulhw %ymm\zh1,%ymm15,%ymm15 /* ymm15 = (rh3 - rl3) * root1 */ + +vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 /* rh0 = Q * [(rh0 - rl0) * root0_twisted] */ +vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 /* rh1 = Q * [(rh1 - rl1) * root0_twisted] */ +vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 /* rh2 = Q * [(rh2 - rl2) * root1_twisted] */ +vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 /* rh3 = Q * [(rh3 - rl3) * root1_twisted] */ + +vpsubw %ymm\rh0,%ymm12,%ymm\rh0 /* rh0 = montmul(rh0-rl0, root0) */ +vpsubw %ymm\rh1,%ymm13,%ymm\rh1 /* rh1 = montmul(rh1-rl1, root0) */ +vpsubw %ymm\rh2,%ymm14,%ymm\rh2 /* rh2 = montmul(rh2-rl2, root1) */ +vpsubw %ymm\rh3,%ymm15,%ymm\rh3 /* rh3 = montmul(rh3-rl3, root1) */ +.endm + +.macro intt_levels0t5 off +/* level 0 */ +/* no bounds assumptions */ +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XFLO*2(%rsi),%ymm2
+vmovdqa AVX2_BACKEND_DATA_OFFSET_16XFHI*2(%rsi),%ymm3 + +vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 + +fqmulprecomp 2,3,4 +fqmulprecomp 2,3,6 +fqmulprecomp 2,3,5 +fqmulprecomp 2,3,7 + +vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 +vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 +vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 +vmovdqa (128*\off+112)*2(%rdi),%ymm11 + +fqmulprecomp 2,3,8 +fqmulprecomp 2,3,10 +fqmulprecomp 2,3,9 +fqmulprecomp 2,3,11 + +/* bounds: coefficients < q */ + +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+192)*2(%rsi),%ymm3 +vmovdqa AVX2_BACKEND_DATA_OFFSET_REVIDXB*2(%rsi),%ymm12 +vpshufb %ymm12,%ymm15,%ymm15 +vpshufb %ymm12,%ymm1,%ymm1 +vpshufb %ymm12,%ymm2,%ymm2 +vpshufb %ymm12,%ymm3,%ymm3 + +butterfly 4,5,8,9,6,7,10,11,15,1,2,3 + +/* Montgomery multiplication with a signed canonical twiddle + * always has absolute value < q. This is used henceforth to + * normalize the absolute bounds on the second half inputs + * to the current butterfly + * + * 4,5,8,9 abs bound < 2q; 6,7,10,11 abs bound < q */ + +/* level 1 */ +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3 +vmovdqa AVX2_BACKEND_DATA_OFFSET_REVIDXB*2(%rsi),%ymm1 +vpshufb %ymm1,%ymm2,%ymm2 +vpshufb %ymm1,%ymm3,%ymm3 + +butterfly 4,5,6,7,8,9,10,11,2,2,3,3 + +/* For 8,9,10,11, it is sufficient to use the bound INT16_MAX).
*/ +red16 7 +/* global abs bound < 4q */ + +vmovdqa %ymm7,(128*\off+ 0)*2(%rdi) +vmovdqa %ymm9,(128*\off+ 16)*2(%rdi) +vmovdqa %ymm6,(128*\off+ 32)*2(%rdi) +vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) +vmovdqa %ymm10,(128*\off+ 64)*2(%rdi) +vmovdqa %ymm4,(128*\off+ 80)*2(%rdi) +vmovdqa %ymm5,(128*\off+ 96)*2(%rdi) +vmovdqa %ymm11,(128*\off+112)*2(%rdi) +.endm + +.macro intt_level6 off +/* level 6 */ +vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (64*\off+128)*2(%rdi),%ymm8 +vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (64*\off+144)*2(%rdi),%ymm9 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+0)*2(%rsi),%ymm2 + +vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (64*\off+160)*2(%rdi),%ymm10 +vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 +vmovdqa (64*\off+176)*2(%rdi),%ymm11 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+4)*2(%rsi),%ymm3 + +butterfly 4,5,6,7,8,9,10,11 +/* global abs bound < 8q */ + +/* REF-CHANGE: The official AVX2 implementation has a `red16 4` for `off=0`. + * We don't need this because of the earlier red16 which ensures an 8q bound */ + +vmovdqa %ymm4,(64*\off+ 0)*2(%rdi) +vmovdqa %ymm5,(64*\off+ 16)*2(%rdi) +vmovdqa %ymm6,(64*\off+ 32)*2(%rdi) +vmovdqa %ymm7,(64*\off+ 48)*2(%rdi) +vmovdqa %ymm8,(64*\off+128)*2(%rdi) +vmovdqa %ymm9,(64*\off+144)*2(%rdi) +vmovdqa %ymm10,(64*\off+160)*2(%rdi) +vmovdqa %ymm11,(64*\off+176)*2(%rdi) +.endm + +.text +.global MLKEM_ASM_NAMESPACE(invntt_avx2) +MLKEM_ASM_NAMESPACE(invntt_avx2): +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 + +intt_levels0t5 0 +intt_levels0t5 1 + +intt_level6 0 +intt_level6 1 +ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/ntt.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/ntt.S new file mode 100644 index 0000000000..e8bf7894b4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/ntt.S @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "consts.h" +#include "shuffle.inc" + +/* Compute steps 1,2 / 3 of Montgomery multiplication */ +.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2 +vpmullw %ymm\zl0,%ymm\rh0,%ymm12 +vpmullw %ymm\zl0,%ymm\rh1,%ymm13 + +vpmullw %ymm\zl1,%ymm\rh2,%ymm14 +vpmullw %ymm\zl1,%ymm\rh3,%ymm15 + +vpmulhw %ymm\zh0,%ymm\rh0,%ymm\rh0 +vpmulhw %ymm\zh0,%ymm\rh1,%ymm\rh1 + +vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2 +vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3 +.endm + +/* Compute step 3 / 3 of Montgomery multiplication */ +/* Multiply-high is signed; outputs are bound by 2^15 * q in abs value */ +.macro reduce +vpmulhw %ymm0,%ymm12,%ymm12 +vpmulhw %ymm0,%ymm13,%ymm13 + +vpmulhw %ymm0,%ymm14,%ymm14 +vpmulhw %ymm0,%ymm15,%ymm15 +.endm + +/* Finish Montgomery multiplication and compute add/sub steps in NTT butterfly + * + * At this point, the two high-products of 4 ongoing Montgomery multiplications + * are in %ymm{12,13,14,15} and %ymm{rh{0,1,2,3}}, respectively. + * The NTT coefficients that the results of the Montgomery multiplications should + * be add/sub-ed with, are in %ymm{rl{0,1,2,3}}. + * + * What's interesting, here, is that rather than completing the Montgomery + * multiplications by computing `%ymm{12+i} + %ymm{rh{i}}`, and then add/sub'ing + * the result into %ymm{rl{0,1,2,3}}, we add/sub both `%ymm{12+i}` and + * %ymm{rh{i}} to %ymm{rl{0,1,2,3}}, and then add the results. + * + * Functionally, though, this is still a signed Montgomery multiplication + * followed by an add/sub. + * + * Since the result of the Montgomery multiplication is bounded + * by q in absolute value, the coefficients overall grow by not + * more than q in absolute value per layer. 
*/ +.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3 +vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln /* rln = rl0 + rh0 */ +vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 /* rh0 = rl0 - rh0 */ +vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 /* rl0 = rl1 + rh1 */ +vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 /* rh1 = rl1 - rh1 */ +vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 /* rl1 = rl2 + rh2 */ +vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 /* rh2 = rl2 - rh2 */ +vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 /* rl2 = rl3 + rh3 */ +vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 /* rh3 = rl3 - rh3 */ + +vpsubw %ymm12,%ymm\rln,%ymm\rln /* rln = rh0 + rl0 - ymm12 = rl0 + (rh0 - ymm12) */ +vpaddw %ymm12,%ymm\rh0,%ymm\rh0 /* rh0 = rl0 - rh0 + ymm12 = rl0 - (rh0 - ymm12) */ +vpsubw %ymm13,%ymm\rl0,%ymm\rl0 /* rl0 = rl1 + rh1 - ymm13 = rl1 + (rh1 - ymm13) */ +vpaddw %ymm13,%ymm\rh1,%ymm\rh1 /* rh1 = rl1 - rh1 + ymm13 = rl1 - (rh1 - ymm13) */ +vpsubw %ymm14,%ymm\rl1,%ymm\rl1 /* rl1 = rh2 + rl2 - ymm14 = rl2 + (rh2 - ymm14) */ +vpaddw %ymm14,%ymm\rh2,%ymm\rh2 /* rh2 = rl2 - rh2 + ymm14 = rl2 - (rh2 - ymm14) */ +vpsubw %ymm15,%ymm\rl2,%ymm\rl2 /* rl2 = rh3 + rl3 - ymm15 = rl3 + (rh3 - ymm15) */ +vpaddw %ymm15,%ymm\rh3,%ymm\rh3 /* rh3 = rl3 - rh3 + ymm15 = rl3 - (rh3 - ymm15) */ +.endm + +.macro level0 off +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+0)*2(%rsi),%ymm15 +vmovdqa (64*\off+128)*2(%rdi),%ymm8 +vmovdqa (64*\off+144)*2(%rdi),%ymm9 +vmovdqa (64*\off+160)*2(%rdi),%ymm10 +vmovdqa (64*\off+176)*2(%rdi),%ymm11 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+4)*2(%rsi),%ymm2 + +mul 8,9,10,11 + +vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 + +reduce +update 3,4,5,6,7,8,9,10,11 + +vmovdqa %ymm3,(64*\off+ 0)*2(%rdi) +vmovdqa %ymm4,(64*\off+ 16)*2(%rdi) +vmovdqa %ymm5,(64*\off+ 32)*2(%rdi) +vmovdqa %ymm6,(64*\off+ 48)*2(%rdi) +vmovdqa %ymm8,(64*\off+128)*2(%rdi) +vmovdqa %ymm9,(64*\off+144)*2(%rdi) +vmovdqa %ymm10,(64*\off+160)*2(%rdi) +vmovdqa 
%ymm11,(64*\off+176)*2(%rdi) +.endm + +.macro levels1t6 off +/* level 1 */ +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+16)*2(%rsi),%ymm15 +vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 +vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 +vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 +vmovdqa (128*\off+112)*2(%rdi),%ymm11 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+32)*2(%rsi),%ymm2 + +mul 8,9,10,11 + +vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 + +reduce +update 3,4,5,6,7,8,9,10,11 + +/* level 2 */ +shuffle8 5,10,7,10 +shuffle8 6,11,5,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+48)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+64)*2(%rsi),%ymm2 + +mul 7,10,5,11 + +shuffle8 3,8,6,8 +shuffle8 4,9,3,9 + +reduce +update 4,6,8,3,9,7,10,5,11 + +/* level 3 */ +shuffle4 8,5,9,5 +shuffle4 3,11,8,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+80)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+96)*2(%rsi),%ymm2 + +mul 9,5,8,11 + +shuffle4 4,7,3,7 +shuffle4 6,10,4,10 + +reduce +update 6,3,7,4,10,9,5,8,11 + +/* level 4 */ +shuffle2 7,8,10,8 +shuffle2 4,11,7,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+112)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+128)*2(%rsi),%ymm2 + +mul 10,8,7,11 + +shuffle2 6,9,4,9 +shuffle2 3,5,6,5 + +reduce +update 3,4,9,6,5,10,8,7,11 + +/* level 5 */ +shuffle1 9,7,5,7 +shuffle1 6,11,9,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+144)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+160)*2(%rsi),%ymm2 + +mul 5,7,9,11 + +shuffle1 3,10,6,10 +shuffle1 4,8,3,8 + +reduce +update 4,6,10,3,8,5,7,9,11 + +/* level 6 */ +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+176)*2(%rsi),%ymm14 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+208)*2(%rsi),%ymm15 +vmovdqa 
(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+192)*2(%rsi),%ymm8 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+224)*2(%rsi),%ymm2 + +mul 10,3,9,11,14,15,8,2 + +reduce +update 8,4,6,5,7,10,3,9,11 + +vmovdqa %ymm8,(128*\off+ 0)*2(%rdi) +vmovdqa %ymm4,(128*\off+ 16)*2(%rdi) +vmovdqa %ymm10,(128*\off+ 32)*2(%rdi) +vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) +vmovdqa %ymm6,(128*\off+ 64)*2(%rdi) +vmovdqa %ymm5,(128*\off+ 80)*2(%rdi) +vmovdqa %ymm9,(128*\off+ 96)*2(%rdi) +vmovdqa %ymm11,(128*\off+112)*2(%rdi) +.endm + +.text +.global MLKEM_ASM_NAMESPACE(ntt_avx2) +MLKEM_ASM_NAMESPACE(ntt_avx2): +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 + +level0 0 +level0 1 + +levels1t6 0 +levels1t6 1 + +ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_avx2.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_avx2.c new file mode 100644 index 0000000000..54037a0df9 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_avx2.c @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include +#include +#include +#include "arith_native_x86_64.h" +#include "consts.h" + +unsigned int rej_uniform_avx2(int16_t *RESTRICT r, const uint8_t *buf) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + uint32_t good; + const __m256i bound = + _mm256_load_si256(&qdata.vec[AVX2_BACKEND_DATA_OFFSET_16XQ / 16]); + const __m256i ones = _mm256_set1_epi8(1); + const __m256i mask = _mm256_set1_epi16(0xFFF); + const __m256i idx8 = + _mm256_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4, + 11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0); + __m256i f0, f1, g0, g1, g2, g3; 
+ __m128i f, t, pilo, pihi; + + ctr = pos = 0; + while (ctr <= MLKEM_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 48) + { + f0 = _mm256_loadu_si256((__m256i *)&buf[pos]); + /* Don't load from offset 24, as this would over-read the buffer */ + f1 = _mm256_loadu_si256((__m256i *)&buf[pos + 16]); + f0 = _mm256_permute4x64_epi64(f0, 0x94 /* 0b10010100 ~= (2,1,1,0) */); + f1 = _mm256_permute4x64_epi64(f1, 0xe9 /* 0b11101001 ~= (3,2,2,1) */); + f0 = _mm256_shuffle_epi8(f0, idx8); + f1 = _mm256_shuffle_epi8(f1, idx8); + g0 = _mm256_srli_epi16(f0, 4); + g1 = _mm256_srli_epi16(f1, 4); + f0 = _mm256_blend_epi16(f0, g0, 0xAA); + f1 = _mm256_blend_epi16(f1, g1, 0xAA); + f0 = _mm256_and_si256(f0, mask); + f1 = _mm256_and_si256(f1, mask); + pos += 48; + + g0 = _mm256_cmpgt_epi16(bound, f0); + g1 = _mm256_cmpgt_epi16(bound, f1); + + g0 = _mm256_packs_epi16(g0, g1); + good = _mm256_movemask_epi8(g0); + + g0 = _mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 0) & 0xFF])); + g1 = _mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 8) & 0xFF])); + g0 = _mm256_inserti128_si256( + g0, _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 16) & 0xFF]), + 1); + g1 = _mm256_inserti128_si256( + g1, _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 24) & 0xFF]), + 1); + + g2 = _mm256_add_epi8(g0, ones); + g3 = _mm256_add_epi8(g1, ones); + g0 = _mm256_unpacklo_epi8(g0, g2); + g1 = _mm256_unpacklo_epi8(g1, g3); + + f0 = _mm256_shuffle_epi8(f0, g0); + f1 = _mm256_shuffle_epi8(f1, g1); + + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f0)); + ctr += _mm_popcnt_u32((good >> 0) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f0, 1)); + ctr += _mm_popcnt_u32((good >> 16) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f1)); + ctr += _mm_popcnt_u32((good >> 8) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f1, 1)); + ctr += _mm_popcnt_u32((good >>
24) & 0xFF); + } + + while (ctr <= MLKEM_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 24) + { + f = _mm_loadu_si128((__m128i *)&buf[pos]); + f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8)); + t = _mm_srli_epi16(f, 4); + f = _mm_blend_epi16(f, t, 0xAA); + f = _mm_and_si128(f, _mm256_castsi256_si128(mask)); + pos += 12; + + t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f); + good = _mm_movemask_epi8(t); + + good = _pext_u32(good, 0x5555); + pilo = _mm_loadl_epi64((__m128i *)&rej_uniform_table[good]); + + pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones)); + pilo = _mm_unpacklo_epi8(pilo, pihi); + f = _mm_shuffle_epi8(f, pilo); + _mm_storeu_si128((__m128i *)&r[ctr], f); + ctr += _mm_popcnt_u32(good); + } + + while (ctr < MLKEM_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)); + pos += 3; + + if (val0 < MLKEM_Q) + r[ctr++] = val0; + if (val1 < MLKEM_Q && ctr < MLKEM_N) + r[ctr++] = val1; + } + + return ctr; +} + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_rej_uniform_avx2 MLKEM_NAMESPACE(empty_cu_rej_uniform_avx2) +int empty_cu_rej_uniform_avx2; +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_table.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_table.c new file mode 100644 index 0000000000..9bbc47146f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/rej_uniform_table.c @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. 
+ */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include +#include "arith_native_x86_64.h" + +/* + * Lookup table used by rejection sampling of the public matrix. + * See autogen for details. + */ +ALIGN const uint8_t rej_uniform_table[256][8] = { + {-1, -1, -1, -1, -1, -1, -1, -1}, {0, -1, -1, -1, -1, -1, -1, -1}, + {2, -1, -1, -1, -1, -1, -1, -1}, {0, 2, -1, -1, -1, -1, -1, -1}, + {4, -1, -1, -1, -1, -1, -1, -1}, {0, 4, -1, -1, -1, -1, -1, -1}, + {2, 4, -1, -1, -1, -1, -1, -1}, {0, 2, 4, -1, -1, -1, -1, -1}, + {6, -1, -1, -1, -1, -1, -1, -1}, {0, 6, -1, -1, -1, -1, -1, -1}, + {2, 6, -1, -1, -1, -1, -1, -1}, {0, 2, 6, -1, -1, -1, -1, -1}, + {4, 6, -1, -1, -1, -1, -1, -1}, {0, 4, 6, -1, -1, -1, -1, -1}, + {2, 4, 6, -1, -1, -1, -1, -1}, {0, 2, 4, 6, -1, -1, -1, -1}, + {8, -1, -1, -1, -1, -1, -1, -1}, {0, 8, -1, -1, -1, -1, -1, -1}, + {2, 8, -1, -1, -1, -1, -1, -1}, {0, 2, 8, -1, -1, -1, -1, -1}, + {4, 8, -1, -1, -1, -1, -1, -1}, {0, 4, 8, -1, -1, -1, -1, -1}, + {2, 4, 8, -1, -1, -1, -1, -1}, {0, 2, 4, 8, -1, -1, -1, -1}, + {6, 8, -1, -1, -1, -1, -1, -1}, {0, 6, 8, -1, -1, -1, -1, -1}, + {2, 6, 8, -1, -1, -1, -1, -1}, {0, 2, 6, 8, -1, -1, -1, -1}, + {4, 6, 8, -1, -1, -1, -1, -1}, {0, 4, 6, 8, -1, -1, -1, -1}, + {2, 4, 6, 8, -1, -1, -1, -1}, {0, 2, 4, 6, 8, -1, -1, -1}, + {10, -1, -1, -1, -1, -1, -1, -1}, {0, 10, -1, -1, -1, -1, -1, -1}, + {2, 10, -1, -1, -1, -1, -1, -1}, {0, 2, 10, -1, -1, -1, -1, -1}, + {4, 10, -1, -1, -1, -1, -1, -1}, {0, 4, 10, -1, -1, -1, -1, -1}, + {2, 4, 10, -1, -1, -1, -1, -1}, {0, 2, 4, 10, -1, -1, -1, -1}, + {6, 10, -1, -1, -1, -1, -1, -1}, {0, 6, 10, -1, -1, -1, -1, -1}, + {2, 6, 10, -1, -1, -1, -1, -1}, {0, 2, 6, 10, -1, -1, -1, -1}, + {4, 6, 10, -1, -1, -1, -1, -1}, {0, 4, 6, 10, -1, -1, -1, -1}, + {2, 4, 6, 10, -1, -1, -1, -1}, {0, 2, 4, 6, 10, -1, -1, -1}, + {8, 10, -1, -1, -1, -1, -1, -1}, {0, 8, 10, -1, -1, -1, -1, -1}, + {2, 8, 10, -1, -1, -1, -1, -1}, {0, 2, 8, 10, -1, -1, -1, -1}, + {4, 8, 
10, -1, -1, -1, -1, -1}, {0, 4, 8, 10, -1, -1, -1, -1}, + {2, 4, 8, 10, -1, -1, -1, -1}, {0, 2, 4, 8, 10, -1, -1, -1}, + {6, 8, 10, -1, -1, -1, -1, -1}, {0, 6, 8, 10, -1, -1, -1, -1}, + {2, 6, 8, 10, -1, -1, -1, -1}, {0, 2, 6, 8, 10, -1, -1, -1}, + {4, 6, 8, 10, -1, -1, -1, -1}, {0, 4, 6, 8, 10, -1, -1, -1}, + {2, 4, 6, 8, 10, -1, -1, -1}, {0, 2, 4, 6, 8, 10, -1, -1}, + {12, -1, -1, -1, -1, -1, -1, -1}, {0, 12, -1, -1, -1, -1, -1, -1}, + {2, 12, -1, -1, -1, -1, -1, -1}, {0, 2, 12, -1, -1, -1, -1, -1}, + {4, 12, -1, -1, -1, -1, -1, -1}, {0, 4, 12, -1, -1, -1, -1, -1}, + {2, 4, 12, -1, -1, -1, -1, -1}, {0, 2, 4, 12, -1, -1, -1, -1}, + {6, 12, -1, -1, -1, -1, -1, -1}, {0, 6, 12, -1, -1, -1, -1, -1}, + {2, 6, 12, -1, -1, -1, -1, -1}, {0, 2, 6, 12, -1, -1, -1, -1}, + {4, 6, 12, -1, -1, -1, -1, -1}, {0, 4, 6, 12, -1, -1, -1, -1}, + {2, 4, 6, 12, -1, -1, -1, -1}, {0, 2, 4, 6, 12, -1, -1, -1}, + {8, 12, -1, -1, -1, -1, -1, -1}, {0, 8, 12, -1, -1, -1, -1, -1}, + {2, 8, 12, -1, -1, -1, -1, -1}, {0, 2, 8, 12, -1, -1, -1, -1}, + {4, 8, 12, -1, -1, -1, -1, -1}, {0, 4, 8, 12, -1, -1, -1, -1}, + {2, 4, 8, 12, -1, -1, -1, -1}, {0, 2, 4, 8, 12, -1, -1, -1}, + {6, 8, 12, -1, -1, -1, -1, -1}, {0, 6, 8, 12, -1, -1, -1, -1}, + {2, 6, 8, 12, -1, -1, -1, -1}, {0, 2, 6, 8, 12, -1, -1, -1}, + {4, 6, 8, 12, -1, -1, -1, -1}, {0, 4, 6, 8, 12, -1, -1, -1}, + {2, 4, 6, 8, 12, -1, -1, -1}, {0, 2, 4, 6, 8, 12, -1, -1}, + {10, 12, -1, -1, -1, -1, -1, -1}, {0, 10, 12, -1, -1, -1, -1, -1}, + {2, 10, 12, -1, -1, -1, -1, -1}, {0, 2, 10, 12, -1, -1, -1, -1}, + {4, 10, 12, -1, -1, -1, -1, -1}, {0, 4, 10, 12, -1, -1, -1, -1}, + {2, 4, 10, 12, -1, -1, -1, -1}, {0, 2, 4, 10, 12, -1, -1, -1}, + {6, 10, 12, -1, -1, -1, -1, -1}, {0, 6, 10, 12, -1, -1, -1, -1}, + {2, 6, 10, 12, -1, -1, -1, -1}, {0, 2, 6, 10, 12, -1, -1, -1}, + {4, 6, 10, 12, -1, -1, -1, -1}, {0, 4, 6, 10, 12, -1, -1, -1}, + {2, 4, 6, 10, 12, -1, -1, -1}, {0, 2, 4, 6, 10, 12, -1, -1}, + {8, 10, 12, -1, -1, -1, -1, -1}, {0, 8, 10, 12, -1, -1, 
-1, -1}, + {2, 8, 10, 12, -1, -1, -1, -1}, {0, 2, 8, 10, 12, -1, -1, -1}, + {4, 8, 10, 12, -1, -1, -1, -1}, {0, 4, 8, 10, 12, -1, -1, -1}, + {2, 4, 8, 10, 12, -1, -1, -1}, {0, 2, 4, 8, 10, 12, -1, -1}, + {6, 8, 10, 12, -1, -1, -1, -1}, {0, 6, 8, 10, 12, -1, -1, -1}, + {2, 6, 8, 10, 12, -1, -1, -1}, {0, 2, 6, 8, 10, 12, -1, -1}, + {4, 6, 8, 10, 12, -1, -1, -1}, {0, 4, 6, 8, 10, 12, -1, -1}, + {2, 4, 6, 8, 10, 12, -1, -1}, {0, 2, 4, 6, 8, 10, 12, -1}, + {14, -1, -1, -1, -1, -1, -1, -1}, {0, 14, -1, -1, -1, -1, -1, -1}, + {2, 14, -1, -1, -1, -1, -1, -1}, {0, 2, 14, -1, -1, -1, -1, -1}, + {4, 14, -1, -1, -1, -1, -1, -1}, {0, 4, 14, -1, -1, -1, -1, -1}, + {2, 4, 14, -1, -1, -1, -1, -1}, {0, 2, 4, 14, -1, -1, -1, -1}, + {6, 14, -1, -1, -1, -1, -1, -1}, {0, 6, 14, -1, -1, -1, -1, -1}, + {2, 6, 14, -1, -1, -1, -1, -1}, {0, 2, 6, 14, -1, -1, -1, -1}, + {4, 6, 14, -1, -1, -1, -1, -1}, {0, 4, 6, 14, -1, -1, -1, -1}, + {2, 4, 6, 14, -1, -1, -1, -1}, {0, 2, 4, 6, 14, -1, -1, -1}, + {8, 14, -1, -1, -1, -1, -1, -1}, {0, 8, 14, -1, -1, -1, -1, -1}, + {2, 8, 14, -1, -1, -1, -1, -1}, {0, 2, 8, 14, -1, -1, -1, -1}, + {4, 8, 14, -1, -1, -1, -1, -1}, {0, 4, 8, 14, -1, -1, -1, -1}, + {2, 4, 8, 14, -1, -1, -1, -1}, {0, 2, 4, 8, 14, -1, -1, -1}, + {6, 8, 14, -1, -1, -1, -1, -1}, {0, 6, 8, 14, -1, -1, -1, -1}, + {2, 6, 8, 14, -1, -1, -1, -1}, {0, 2, 6, 8, 14, -1, -1, -1}, + {4, 6, 8, 14, -1, -1, -1, -1}, {0, 4, 6, 8, 14, -1, -1, -1}, + {2, 4, 6, 8, 14, -1, -1, -1}, {0, 2, 4, 6, 8, 14, -1, -1}, + {10, 14, -1, -1, -1, -1, -1, -1}, {0, 10, 14, -1, -1, -1, -1, -1}, + {2, 10, 14, -1, -1, -1, -1, -1}, {0, 2, 10, 14, -1, -1, -1, -1}, + {4, 10, 14, -1, -1, -1, -1, -1}, {0, 4, 10, 14, -1, -1, -1, -1}, + {2, 4, 10, 14, -1, -1, -1, -1}, {0, 2, 4, 10, 14, -1, -1, -1}, + {6, 10, 14, -1, -1, -1, -1, -1}, {0, 6, 10, 14, -1, -1, -1, -1}, + {2, 6, 10, 14, -1, -1, -1, -1}, {0, 2, 6, 10, 14, -1, -1, -1}, + {4, 6, 10, 14, -1, -1, -1, -1}, {0, 4, 6, 10, 14, -1, -1, -1}, + {2, 4, 6, 10, 14, -1, -1, -1}, {0, 2, 
4, 6, 10, 14, -1, -1}, + {8, 10, 14, -1, -1, -1, -1, -1}, {0, 8, 10, 14, -1, -1, -1, -1}, + {2, 8, 10, 14, -1, -1, -1, -1}, {0, 2, 8, 10, 14, -1, -1, -1}, + {4, 8, 10, 14, -1, -1, -1, -1}, {0, 4, 8, 10, 14, -1, -1, -1}, + {2, 4, 8, 10, 14, -1, -1, -1}, {0, 2, 4, 8, 10, 14, -1, -1}, + {6, 8, 10, 14, -1, -1, -1, -1}, {0, 6, 8, 10, 14, -1, -1, -1}, + {2, 6, 8, 10, 14, -1, -1, -1}, {0, 2, 6, 8, 10, 14, -1, -1}, + {4, 6, 8, 10, 14, -1, -1, -1}, {0, 4, 6, 8, 10, 14, -1, -1}, + {2, 4, 6, 8, 10, 14, -1, -1}, {0, 2, 4, 6, 8, 10, 14, -1}, + {12, 14, -1, -1, -1, -1, -1, -1}, {0, 12, 14, -1, -1, -1, -1, -1}, + {2, 12, 14, -1, -1, -1, -1, -1}, {0, 2, 12, 14, -1, -1, -1, -1}, + {4, 12, 14, -1, -1, -1, -1, -1}, {0, 4, 12, 14, -1, -1, -1, -1}, + {2, 4, 12, 14, -1, -1, -1, -1}, {0, 2, 4, 12, 14, -1, -1, -1}, + {6, 12, 14, -1, -1, -1, -1, -1}, {0, 6, 12, 14, -1, -1, -1, -1}, + {2, 6, 12, 14, -1, -1, -1, -1}, {0, 2, 6, 12, 14, -1, -1, -1}, + {4, 6, 12, 14, -1, -1, -1, -1}, {0, 4, 6, 12, 14, -1, -1, -1}, + {2, 4, 6, 12, 14, -1, -1, -1}, {0, 2, 4, 6, 12, 14, -1, -1}, + {8, 12, 14, -1, -1, -1, -1, -1}, {0, 8, 12, 14, -1, -1, -1, -1}, + {2, 8, 12, 14, -1, -1, -1, -1}, {0, 2, 8, 12, 14, -1, -1, -1}, + {4, 8, 12, 14, -1, -1, -1, -1}, {0, 4, 8, 12, 14, -1, -1, -1}, + {2, 4, 8, 12, 14, -1, -1, -1}, {0, 2, 4, 8, 12, 14, -1, -1}, + {6, 8, 12, 14, -1, -1, -1, -1}, {0, 6, 8, 12, 14, -1, -1, -1}, + {2, 6, 8, 12, 14, -1, -1, -1}, {0, 2, 6, 8, 12, 14, -1, -1}, + {4, 6, 8, 12, 14, -1, -1, -1}, {0, 4, 6, 8, 12, 14, -1, -1}, + {2, 4, 6, 8, 12, 14, -1, -1}, {0, 2, 4, 6, 8, 12, 14, -1}, + {10, 12, 14, -1, -1, -1, -1, -1}, {0, 10, 12, 14, -1, -1, -1, -1}, + {2, 10, 12, 14, -1, -1, -1, -1}, {0, 2, 10, 12, 14, -1, -1, -1}, + {4, 10, 12, 14, -1, -1, -1, -1}, {0, 4, 10, 12, 14, -1, -1, -1}, + {2, 4, 10, 12, 14, -1, -1, -1}, {0, 2, 4, 10, 12, 14, -1, -1}, + {6, 10, 12, 14, -1, -1, -1, -1}, {0, 6, 10, 12, 14, -1, -1, -1}, + {2, 6, 10, 12, 14, -1, -1, -1}, {0, 2, 6, 10, 12, 14, -1, -1}, + {4, 6, 10, 12, 14, -1, 
-1, -1}, {0, 4, 6, 10, 12, 14, -1, -1}, + {2, 4, 6, 10, 12, 14, -1, -1}, {0, 2, 4, 6, 10, 12, 14, -1}, + {8, 10, 12, 14, -1, -1, -1, -1}, {0, 8, 10, 12, 14, -1, -1, -1}, + {2, 8, 10, 12, 14, -1, -1, -1}, {0, 2, 8, 10, 12, 14, -1, -1}, + {4, 8, 10, 12, 14, -1, -1, -1}, {0, 4, 8, 10, 12, 14, -1, -1}, + {2, 4, 8, 10, 12, 14, -1, -1}, {0, 2, 4, 8, 10, 12, 14, -1}, + {6, 8, 10, 12, 14, -1, -1, -1}, {0, 6, 8, 10, 12, 14, -1, -1}, + {2, 6, 8, 10, 12, 14, -1, -1}, {0, 2, 6, 8, 10, 12, 14, -1}, + {4, 6, 8, 10, 12, 14, -1, -1}, {0, 4, 6, 8, 10, 12, 14, -1}, + {2, 4, 6, 8, 10, 12, 14, -1}, {0, 2, 4, 6, 8, 10, 12, 14}, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_avx2_rej_uniform_table \ + MLKEM_NAMESPACE(empty_cu_avx2_rej_uniform_table) +int empty_cu_avx2_rej_uniform_table; +#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.S b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.S similarity index 81% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.S rename to src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.S index 18325ebec0..5e708748a8 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.S @@ -1,9 +1,21 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + #include "consts.h" -.include "fq.inc" -.include "shuffle.inc" +#include "fq.inc" +#include "shuffle.inc" -/* -nttpack_avx: +.global MLKEM_ASM_NAMESPACE(nttpack_avx2) +MLKEM_ASM_NAMESPACE(nttpack_avx2): #load vmovdqa (%rdi),%ymm4 vmovdqa 32(%rdi),%ymm5 @@ -45,10 +57,8 @@ vmovdqa %ymm5,192(%rdi) vmovdqa %ymm11,224(%rdi) ret -*/ 
-.text -nttunpack128_avx: +nttunpack128_avx2: #load vmovdqa (%rdi),%ymm4 vmovdqa 32(%rdi),%ymm5 @@ -91,11 +101,11 @@ vmovdqa %ymm11,224(%rdi) ret -.global cdecl(nttunpack_avx) -cdecl(nttunpack_avx): -call nttunpack128_avx +.global MLKEM_ASM_NAMESPACE(nttunpack_avx2) +MLKEM_ASM_NAMESPACE(nttunpack_avx2): +call nttunpack128_avx2 add $256,%rdi -call nttunpack128_avx +call nttunpack128_avx2 ret ntttobytes128_avx: @@ -109,16 +119,6 @@ vmovdqa 160(%rsi),%ymm10 vmovdqa 192(%rsi),%ymm11 vmovdqa 224(%rsi),%ymm12 -#csubq -csubq 5,13 -csubq 6,13 -csubq 7,13 -csubq 8,13 -csubq 9,13 -csubq 10,13 -csubq 11,13 -csubq 12,13 - #bitpack vpsllw $12,%ymm6,%ymm4 vpor %ymm4,%ymm5,%ymm4 @@ -168,10 +168,10 @@ vmovdqu %ymm9,160(%rdi) ret -.global cdecl(ntttobytes_avx) -cdecl(ntttobytes_avx): +.global MLKEM_ASM_NAMESPACE(ntttobytes_avx2) +MLKEM_ASM_NAMESPACE(ntttobytes_avx2): #consts -vmovdqa _16XQ*2(%rdx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rdx),%ymm0 call ntttobytes128_avx add $256,%rsi add $192,%rdi @@ -244,12 +244,14 @@ vmovdqa %ymm1,224(%rdi) ret -.global cdecl(nttfrombytes_avx) -cdecl(nttfrombytes_avx): +.global MLKEM_ASM_NAMESPACE(nttfrombytes_avx2) +MLKEM_ASM_NAMESPACE(nttfrombytes_avx2): #consts -vmovdqa _16XMASK*2(%rdx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMASK*2(%rdx),%ymm0 call nttfrombytes128_avx add $256,%rdi add $192,%rsi call nttfrombytes128_avx ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.inc b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.inc similarity index 55% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.inc rename to src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.inc index 73e9ffe03c..359807bd25 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/shuffle.inc +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/shuffle.inc @@ -1,3 +1,8 @@ +/* + * Copyright (c) 
2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + .macro shuffle8 r0,r1,r2,r3 vperm2i128 $0x20,%ymm\r1,%ymm\r0,%ymm\r2 vperm2i128 $0x31,%ymm\r1,%ymm\r0,%ymm\r3 @@ -8,12 +13,19 @@ vpunpcklqdq %ymm\r1,%ymm\r0,%ymm\r2 vpunpckhqdq %ymm\r1,%ymm\r0,%ymm\r3 .endm +/* Shuffle r0=(a0,b0,c0,d0,...), r1=(a1,b1,c1,d1,...) into */ +/* r2 = (a0,b0,a1,b1,e0,f0,e1,f1,...) */ +/* r3 = (c0,d0,c1,d1,g0,h0,g1,h1,...) */ .macro shuffle2 r0,r1,r2,r3 -#vpsllq $32,%ymm\r1,%ymm\r2 +/* r2=(a1,b1,a1,b1,e1,f1,e1,f1,...) */ vmovsldup %ymm\r1,%ymm\r2 +/* Conditional move */ +/* 0xAA = 0b10101010 */ +/* r2=(a0,b0,a1,b1,e0,f0,e1,f1,...) */ vpblendd $0xAA,%ymm\r2,%ymm\r0,%ymm\r2 +/* r0=(c0,d0,0,0,g0,h0,0,0,...) */ vpsrlq $32,%ymm\r0,%ymm\r0 -#vmovshdup %ymm\r0,%ymm\r0 +/* r3=(c0,d0,c1,d1,g0,h0,g1,h1,...) */ vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/x86_64_zetas.i b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/x86_64_zetas.i new file mode 100644 index 0000000000..26d582ee53 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/x86_64/src/x86_64_zetas.i @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +/* + * Table of zeta values used in the AVX2 NTTs + * See autogen for details. 
+ */ + +31498, 31498, 31498, 31498, -758, -758, -758, -758, 0, 0, 0, 0, 0, 0, 0, 0, + 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, + 14745, 14745, 14745, 14745, 14745, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, -359, 13525, 13525, 13525, + 13525, 13525, 13525, 13525, 13525, -12402, -12402, -12402, -12402, -12402, + -12402, -12402, -12402, 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, -20907, -20907, -20907, + -20907, 27758, 27758, 27758, 27758, -3799, -3799, -3799, -3799, -15690, + -15690, -15690, -15690, -171, -171, -171, -171, 622, 622, 622, 622, 1577, + 1577, 1577, 1577, 182, 182, 182, 182, -5827, -5827, 17363, 17363, -26360, + -26360, -29057, -29057, 5571, 5571, -1102, -1102, 21438, 21438, -26242, + -26242, 573, 573, -1325, -1325, 264, 264, 383, 383, -829, -829, 1458, 1458, + -1602, -1602, -130, -130, -5689, -6516, 1496, 30967, -23565, 20179, 20710, + 25080, -12796, 26616, 16064, -12442, 9134, -650, -25986, 27837, 1223, 652, + -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, -1618, -1162, + 126, 1469, -335, -11477, -32227, 20494, -27738, 945, -14883, 6182, 32010, + 10631, 29175, -28762, -18486, 17560, -14430, -5276, -1103, 555, -1251, 1550, + 422, 177, -291, 1574, -246, 1159, -777, -602, -1590, -872, 418, -156, 11182, + 13387, -14233, -21655, 13131, -4587, 23092, 5493, -32502, 30317, -18741, + 12639, 20100, 18525, 19529, -12619, 430, 843, 871, 105, 587, -235, -460, + 1653, 778, -147, 1483, 1119, 644, 349, 329, -75, 787, 787, 787, 787, 787, + 787, 787, 787, 787, 787, 787, 787, 787, 787, 787, 787, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, 28191, 28191, 28191, 28191, 28191, 28191, 28191, 28191, + -16694, -16694, -16694, -16694, -16694, -16694, -16694, -16694, 287, 287, + 287, 287, 287, 287, 287, 287, 202, 202, 202, 202, 202, 202, 202, 202, 10690, + 
10690, 10690, 10690, 1358, 1358, 1358, 1358, -11202, -11202, -11202, -11202, + 31164, 31164, 31164, 31164, 962, 962, 962, 962, -1202, -1202, -1202, -1202, + -1474, -1474, -1474, -1474, 1468, 1468, 1468, 1468, -28073, -28073, 24313, + 24313, -10532, -10532, 8800, 8800, 18426, 18426, 8859, 8859, 26675, 26675, + -16163, -16163, -681, -681, 1017, 1017, 732, 732, 608, 608, -1542, -1542, + 411, 411, -205, -205, -1571, -1571, 19883, -28250, -15887, -8898, -28309, + 9075, -30199, 18249, 13426, 14017, -29156, -12757, 16832, 4311, -24155, + -17915, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961, -1508, + -725, 448, -1065, 677, -1275, -31183, 25435, -7382, 24391, -20927, 10946, + 24214, 16989, 10335, -7934, -22502, 10906, 31636, 28644, 23998, -17422, 817, + 603, 1322, -1465, -1215, 1218, -874, -1187, -1185, -1278, -1510, -870, -108, + 996, 958, 1522, 20297, 2146, 15355, -32384, -6280, -14903, -11044, 14469, + -21498, -20198, 23210, -17442, -23860, -20257, 7756, 23132, 1097, 610, + -1285, 384, -136, -1335, 220, -1659, -1530, 794, -854, 478, -308, 991, + -1460, 1628, diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-512_x86_64/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. 
+ */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/LICENSE new file mode 100644 index 0000000000..7922ab8007 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/LICENSE @@ -0,0 +1,6 @@ +Public Domain (https://creativecommons.org/share-your-work/public-domain/cc0/); +or Apache 2.0 License (https://www.apache.org/licenses/LICENSE-2.0.html). + +For Keccak and AES we are using public-domain +code from sources and by authors listed in +comments on top of the respective files. diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/README.md b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/README.md new file mode 100644 index 0000000000..e499a4a229 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/README.md @@ -0,0 +1,19 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +# AArch64 backend (little endian) + +This directory contains a native backend for little endian AArch64 systems. 
It is derived from the following research +works: + +- _Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1_, Hanno Becker, Vincent Hwang, Matthias + J. Kannwischer, Bo-Yin Yang, and Shang-Yi Yang, [https://eprint.iacr.org/2021/986](https://eprint.iacr.org/2021/986) +- _Fast and Clean: Auditable high-performance assembly via constraint solving_, Amin Abdulrahman, Hanno Becker, Matthias + J. Kannwischer, Fabien Klein, [https://eprint.iacr.org/2022/1303](https://eprint.iacr.org/2022/1303) + +## Profiles + +This backend comes with two profiles: "clean" and optimized. The "clean" backend is handwritten and meant to be easy to +read and modify; for example, it heavily leverages register aliases and assembly macros. The optimized profile is +automatically generated from the clean profile via [SLOTHY](https://github.com/slothy-optimizer/slothy). Currently, the +target architecture is Cortex-A55, but you can easily re-optimize the code for a different microarchitecture supported +by SLOTHY, by adjusting the parameters in [optimize.sh](src/optimize.sh). diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/clean.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/clean.h new file mode 100644 index 0000000000..43a401dfc4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/clean.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME AARCH64_CLEAN + +/* Filename of the C backend implementation.
+ * This is not inlined here because this header is included in assembly + * files as well. */ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "aarch64/src/clean_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/opt.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/opt.h new file mode 100644 index 0000000000..04323c3e79 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/opt.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for optimized assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME AARCH64_OPT + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. */ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "aarch64/src/opt_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/aarch64_zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/aarch64_zetas.c new file mode 100644 index 0000000000..1e189fd995 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/aarch64_zetas.c @@ -0,0 +1,175 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly.
+ */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \ + defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +#include +#include "arith_native_aarch64.h" + +/* + * Table of zeta values used in the AArch64 forward NTT + * See autogen for details. + */ +ALIGN const int16_t aarch64_ntt_zetas_layer01234[] = { + -1600, -15749, -749, -7373, -40, -394, -687, -6762, 630, 6201, + -1432, -14095, 848, 8347, 0, 0, 1062, 10453, 296, 2914, + -882, -8682, 0, 0, -1410, -13879, 1339, 13180, 1476, 14529, + 0, 0, 193, 1900, -283, -2786, 56, 551, 0, 0, + 797, 7845, -1089, -10719, 1333, 13121, 0, 0, -543, -5345, + 1426, 14036, -1235, -12156, 0, 0, -69, -679, 535, 5266, + -447, -4400, 0, 0, 569, 5601, -936, -9213, -450, -4429, + 0, 0, -1583, -15582, -1355, -13338, 821, 8081, 0, 0, +}; + +ALIGN const int16_t aarch64_ntt_zetas_layer56[] = { + 289, 289, 331, 331, -76, -76, -1573, -1573, 2845, + 2845, 3258, 3258, -748, -748, -15483, -15483, 17, 17, + 583, 583, 1637, 1637, -1041, -1041, 167, 167, 5739, + 5739, 16113, 16113, -10247, -10247, -568, -568, -680, -680, + 723, 723, 1100, 1100, -5591, -5591, -6693, -6693, 7117, + 7117, 10828, 10828, 1197, 1197, -1025, -1025, -1052, -1052, + -1274, -1274, 11782, 11782, -10089, -10089, -10355, -10355, -12540, + -12540, 1409, 1409, -48, -48, 756, 756, -314, -314, + 13869, 13869, -472, -472, 7441, 7441, -3091, -3091, -667, + -667, 233, 233, -1173, -1173, -279, -279, -6565, -6565, + 2293, 2293, -11546, -11546, -2746, -2746, 650, 650, -1352, + -1352, -816, -816, 632, 632, 6398, 6398, -13308, -13308, + -8032, -8032, 6221, 6221, -1626, -1626, -540, -540, -1482, + -1482, 1461, 1461, -16005, -16005, -5315, -5315, -14588, -14588, + 14381, 14381, 1651, 1651, -1540, -1540, 952, 952, -642, + -642, 16251, 16251, -15159, -15159, 9371, 9371, -6319, -6319, + -464, -464, 33, 33, 1320, 1320, -1414, -1414, -4567, + -4567, 325, 325, 12993, 12993, -13918, -13918, 939, 939, + -892, -892, 733, 733, 268, 268, 9243, 9243, -8780, + -8780, 7215, 
7215, 2638, 2638, -1021, -1021, -941, -941, + -992, -992, 641, 641, -10050, -10050, -9262, -9262, -9764, + -9764, 6309, 6309, -1010, -1010, 1435, 1435, 807, 807, + 452, 452, -9942, -9942, 14125, 14125, 7943, 7943, 4449, + 4449, 1584, 1584, -1292, -1292, 375, 375, -1239, -1239, + 15592, 15592, -12717, -12717, 3691, 3691, -12196, -12196, -1031, + -1031, -109, -109, -780, -780, 1645, 1645, -10148, -10148, + -1073, -1073, -7678, -7678, 16192, 16192, 1438, 1438, -461, + -461, 1534, 1534, -927, -927, 14155, 14155, -4538, -4538, + 15099, 15099, -9125, -9125, 1063, 1063, -556, -556, -1230, + -1230, -863, -863, 10463, 10463, -5473, -5473, -12107, -12107, + -8495, -8495, 319, 319, 757, 757, 561, 561, -735, + -735, 3140, 3140, 7451, 7451, 5522, 5522, -7235, -7235, + -682, -682, -712, -712, 1481, 1481, 648, 648, -6713, + -6713, -7008, -7008, 14578, 14578, 6378, 6378, -525, -525, + 403, 403, 1143, 1143, -554, -554, -5168, -5168, 3967, + 3967, 11251, 11251, -5453, -5453, 1092, 1092, 1026, 1026, + -1179, -1179, 886, 886, 10749, 10749, 10099, 10099, -11605, + -11605, 8721, 8721, -855, -855, -219, -219, 1227, 1227, + 910, 910, -8416, -8416, -2156, -2156, 12078, 12078, 8957, + 8957, -1607, -1607, -1455, -1455, -1219, -1219, 885, 885, + -15818, -15818, -14322, -14322, -11999, -11999, 8711, 8711, 1212, + 1212, 1029, 1029, -394, -394, -1175, -1175, 11930, 11930, + 10129, 10129, -3878, -3878, -11566, -11566, +}; + +ALIGN const int16_t aarch64_invntt_zetas_layer01234[] = { + 1583, 15582, -821, -8081, 1355, 13338, 0, 0, -569, -5601, + 450, 4429, 936, 9213, 0, 0, 69, 679, 447, 4400, + -535, -5266, 0, 0, 543, 5345, 1235, 12156, -1426, -14036, + 0, 0, -797, -7845, -1333, -13121, 1089, 10719, 0, 0, + -193, -1900, -56, -551, 283, 2786, 0, 0, 1410, 13879, + -1476, -14529, -1339, -13180, 0, 0, -1062, -10453, 882, 8682, + -296, -2914, 0, 0, 1600, 15749, 40, 394, 749, 7373, + -848, -8347, 1432, 14095, -630, -6201, 687, 6762, 0, 0, +}; + +ALIGN const int16_t aarch64_invntt_zetas_layer56[] = { + 
-910, -910, -1227, -1227, 219, 219, 855, 855, -8957, + -8957, -12078, -12078, 2156, 2156, 8416, 8416, 1175, 1175, + 394, 394, -1029, -1029, -1212, -1212, 11566, 11566, 3878, + 3878, -10129, -10129, -11930, -11930, -885, -885, 1219, 1219, + 1455, 1455, 1607, 1607, -8711, -8711, 11999, 11999, 14322, + 14322, 15818, 15818, -648, -648, -1481, -1481, 712, 712, + 682, 682, -6378, -6378, -14578, -14578, 7008, 7008, 6713, + 6713, -886, -886, 1179, 1179, -1026, -1026, -1092, -1092, + -8721, -8721, 11605, 11605, -10099, -10099, -10749, -10749, 554, + 554, -1143, -1143, -403, -403, 525, 525, 5453, 5453, + -11251, -11251, -3967, -3967, 5168, 5168, 927, 927, -1534, + -1534, 461, 461, -1438, -1438, 9125, 9125, -15099, -15099, + 4538, 4538, -14155, -14155, 735, 735, -561, -561, -757, + -757, -319, -319, 7235, 7235, -5522, -5522, -7451, -7451, + -3140, -3140, 863, 863, 1230, 1230, 556, 556, -1063, + -1063, 8495, 8495, 12107, 12107, 5473, 5473, -10463, -10463, + -452, -452, -807, -807, -1435, -1435, 1010, 1010, -4449, + -4449, -7943, -7943, -14125, -14125, 9942, 9942, -1645, -1645, + 780, 780, 109, 109, 1031, 1031, -16192, -16192, 7678, + 7678, 1073, 1073, 10148, 10148, 1239, 1239, -375, -375, + 1292, 1292, -1584, -1584, 12196, 12196, -3691, -3691, 12717, + 12717, -15592, -15592, 1414, 1414, -1320, -1320, -33, -33, + 464, 464, 13918, 13918, -12993, -12993, -325, -325, 4567, + 4567, -641, -641, 992, 992, 941, 941, 1021, 1021, + -6309, -6309, 9764, 9764, 9262, 9262, 10050, 10050, -268, + -268, -733, -733, 892, 892, -939, -939, -2638, -2638, + -7215, -7215, 8780, 8780, -9243, -9243, -632, -632, 816, + 816, 1352, 1352, -650, -650, -6221, -6221, 8032, 8032, + 13308, 13308, -6398, -6398, 642, 642, -952, -952, 1540, + 1540, -1651, -1651, 6319, 6319, -9371, -9371, 15159, 15159, + -16251, -16251, -1461, -1461, 1482, 1482, 540, 540, 1626, + 1626, -14381, -14381, 14588, 14588, 5315, 5315, 16005, 16005, + 1274, 1274, 1052, 1052, 1025, 1025, -1197, -1197, 12540, + 12540, 10355, 10355, 10089, 
10089, -11782, -11782, 279, 279, + 1173, 1173, -233, -233, 667, 667, 2746, 2746, 11546, + 11546, -2293, -2293, 6565, 6565, 314, 314, -756, -756, + 48, 48, -1409, -1409, 3091, 3091, -7441, -7441, 472, + 472, -13869, -13869, 1573, 1573, 76, 76, -331, -331, + -289, -289, 15483, 15483, 748, 748, -3258, -3258, -2845, + -2845, -1100, -1100, -723, -723, 680, 680, 568, 568, + -10828, -10828, -7117, -7117, 6693, 6693, 5591, 5591, 1041, + 1041, -1637, -1637, -583, -583, -17, -17, 10247, 10247, + -16113, -16113, -5739, -5739, -167, -167, +}; + +ALIGN const int16_t aarch64_zetas_mulcache_native[] = { + 17, -17, -568, 568, 583, -583, -680, 680, 1637, -1637, 723, + -723, -1041, 1041, 1100, -1100, 1409, -1409, -667, 667, -48, 48, + 233, -233, 756, -756, -1173, 1173, -314, 314, -279, 279, -1626, + 1626, 1651, -1651, -540, 540, -1540, 1540, -1482, 1482, 952, -952, + 1461, -1461, -642, 642, 939, -939, -1021, 1021, -892, 892, -941, + 941, 733, -733, -992, 992, 268, -268, 641, -641, 1584, -1584, + -1031, 1031, -1292, 1292, -109, 109, 375, -375, -780, 780, -1239, + 1239, 1645, -1645, 1063, -1063, 319, -319, -556, 556, 757, -757, + -1230, 1230, 561, -561, -863, 863, -735, 735, -525, 525, 1092, + -1092, 403, -403, 1026, -1026, 1143, -1143, -1179, 1179, -554, 554, + 886, -886, -1607, 1607, 1212, -1212, -1455, 1455, 1029, -1029, -1219, + 1219, -394, 394, 885, -885, -1175, 1175, +}; + +ALIGN const int16_t aarch64_zetas_mulcache_twisted_native[] = { + 167, -167, -5591, 5591, 5739, -5739, -6693, 6693, 16113, + -16113, 7117, -7117, -10247, 10247, 10828, -10828, 13869, -13869, + -6565, 6565, -472, 472, 2293, -2293, 7441, -7441, -11546, + 11546, -3091, 3091, -2746, 2746, -16005, 16005, 16251, -16251, + -5315, 5315, -15159, 15159, -14588, 14588, 9371, -9371, 14381, + -14381, -6319, 6319, 9243, -9243, -10050, 10050, -8780, 8780, + -9262, 9262, 7215, -7215, -9764, 9764, 2638, -2638, 6309, + -6309, 15592, -15592, -10148, 10148, -12717, 12717, -1073, 1073, + 3691, -3691, -7678, 7678, -12196, 12196, 
16192, -16192, 10463, + -10463, 3140, -3140, -5473, 5473, 7451, -7451, -12107, 12107, + 5522, -5522, -8495, 8495, -7235, 7235, -5168, 5168, 10749, + -10749, 3967, -3967, 10099, -10099, 11251, -11251, -11605, 11605, + -5453, 5453, 8721, -8721, -15818, 15818, 11930, -11930, -14322, + 14322, 10129, -10129, -11999, 11999, -3878, 3878, 8711, -8711, + -11566, 11566, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_aarch64_zetas MLKEM_NAMESPACE(empty_cu_aarch64_zetas) +int empty_cu_aarch64_zetas; +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/arith_native_aarch64.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/arith_native_aarch64.h new file mode 100644 index 0000000000..6a5ee8a7d6 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/arith_native_aarch64.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_AARCH64_NATIVE_H +#define MLKEM_AARCH64_NATIVE_H + +#include +#include "common.h" + +#define aarch64_ntt_zetas_layer01234 \ + MLKEM_NAMESPACE(aarch64_ntt_zetas_layer01234) +#define aarch64_ntt_zetas_layer56 MLKEM_NAMESPACE(aarch64_ntt_zetas_layer56) +#define aarch64_invntt_zetas_layer01234 \ + MLKEM_NAMESPACE(aarch64_invntt_zetas_layer01234) +#define aarch64_invntt_zetas_layer56 \ + MLKEM_NAMESPACE(aarch64_invntt_zetas_layer56) +#define aarch64_zetas_mulcache_native \ + MLKEM_NAMESPACE(aarch64_zetas_mulcache_native) +#define aarch64_zetas_mulcache_twisted_native \ + MLKEM_NAMESPACE(aarch64_zetas_mulcache_twisted_native) +#define rej_uniform_table MLKEM_NAMESPACE(rej_uniform_table) + +extern const int16_t aarch64_ntt_zetas_layer01234[]; +extern const int16_t aarch64_ntt_zetas_layer56[]; +extern const int16_t aarch64_invntt_zetas_layer01234[]; +extern const int16_t aarch64_invntt_zetas_layer56[]; +extern const int16_t aarch64_zetas_mulcache_native[]; +extern 
const int16_t aarch64_zetas_mulcache_twisted_native[]; +extern const uint8_t rej_uniform_table[]; + +#define ntt_asm_clean MLKEM_NAMESPACE(ntt_asm_clean) +void ntt_asm_clean(int16_t *, const int16_t *, const int16_t *); + +#define ntt_asm_opt MLKEM_NAMESPACE(ntt_asm_opt) +void ntt_asm_opt(int16_t *, const int16_t *, const int16_t *); + +#define intt_asm_clean MLKEM_NAMESPACE(intt_asm_clean) +void intt_asm_clean(int16_t *, const int16_t *, const int16_t *); + +#define intt_asm_opt MLKEM_NAMESPACE(intt_asm_opt) +void intt_asm_opt(int16_t *, const int16_t *, const int16_t *); + +#define rej_uniform_asm_clean MLKEM_NAMESPACE(rej_uniform_asm_clean) +unsigned int rej_uniform_asm_clean(int16_t *r, const uint8_t *buf, + unsigned int buflen, const uint8_t *table); + +#define poly_reduce_asm_clean MLKEM_NAMESPACE(poly_reduce_asm_clean) +void poly_reduce_asm_clean(int16_t *); + +#define poly_reduce_asm_opt MLKEM_NAMESPACE(poly_reduce_asm_opt) +void poly_reduce_asm_opt(int16_t *); + +#define poly_tomont_asm_clean MLKEM_NAMESPACE(poly_tomont_asm_clean) +void poly_tomont_asm_clean(int16_t *); + +#define poly_tomont_asm_opt MLKEM_NAMESPACE(poly_tomont_asm_opt) +void poly_tomont_asm_opt(int16_t *); + +#define poly_mulcache_compute_asm_clean \ + MLKEM_NAMESPACE(poly_mulcache_compute_asm_clean) +void poly_mulcache_compute_asm_clean(int16_t *, const int16_t *, + const int16_t *, const int16_t *); + + +#define poly_mulcache_compute_asm_opt \ + MLKEM_NAMESPACE(poly_mulcache_compute_asm_opt) +void poly_mulcache_compute_asm_opt(int16_t *, const int16_t *, const int16_t *, + const int16_t *); + +#define poly_tobytes_asm_clean MLKEM_NAMESPACE(poly_tobytes_asm_clean) +void poly_tobytes_asm_clean(uint8_t *r, const int16_t *a); + +#define poly_tobytes_asm_opt MLKEM_NAMESPACE(poly_tobytes_asm_opt) +void poly_tobytes_asm_opt(uint8_t *r, const int16_t *a); + +#define polyvec_basemul_acc_montgomery_cached_asm_clean \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) +void 
polyvec_basemul_acc_montgomery_cached_asm_clean(int16_t *r, + const int16_t *a, + const int16_t *b, + const int16_t *b_cache); + +#define polyvec_basemul_acc_montgomery_cached_asm_opt \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) +void polyvec_basemul_acc_montgomery_cached_asm_opt(int16_t *r, const int16_t *a, + const int16_t *b, + const int16_t *b_cache); + +#endif /* MLKEM_AARCH64_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/clean_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/clean_impl.h new file mode 100644 index 0000000000..b0ff3d5972 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/clean_impl.h @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include "arith_native_aarch64.h" + +#include "poly.h" +#include "polyvec.h" + +/* Set of primitives that this backend replaces */ +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_REJ_UNIFORM + +static INLINE void ntt_native(poly *data) +{ + ntt_asm_clean(data->coeffs, aarch64_ntt_zetas_layer01234, + aarch64_ntt_zetas_layer56); +} + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +static INLINE void intt_native(poly *data) +{ + intt_asm_clean(data->coeffs, aarch64_invntt_zetas_layer01234, + aarch64_invntt_zetas_layer56); +} + +static INLINE void poly_reduce_native(poly *data) +{ + poly_reduce_asm_clean(data->coeffs); +} +static INLINE void poly_tomont_native(poly *data) +{ + poly_tomont_asm_clean(data->coeffs); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + poly_mulcache_compute_asm_clean(x->coeffs, y->coeffs, + aarch64_zetas_mulcache_native, + aarch64_zetas_mulcache_twisted_native); +} +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_asm_clean( + r->coeffs, a->vec[0].coeffs, b->vec[0].coeffs, b_cache->vec[0].coeffs); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + poly_tobytes_asm_clean(r, a->coeffs); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + if (len != MLKEM_N || buflen % 24 != 0) + { + return -1; + } + return (int)rej_uniform_asm_clean(r, buf, buflen, rej_uniform_table); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ diff --git 
a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/consts.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/consts.h new file mode 100644 index 0000000000..c40947299c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/consts.h @@ -0,0 +1,19 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_AARCH64_CONSTS) +#define MLKEM_NATIVE_AARCH64_CONSTS + +#include +#include "common.h" + +#define zetas_mulcache_native MLKEM_NAMESPACE(zetas_mulcache_native) +extern const int16_t zetas_mulcache_native[256]; + +#define zetas_mulcache_twisted_native \ + MLKEM_NAMESPACE(zetas_mulcache_twisted_native) +extern const int16_t zetas_mulcache_twisted_native[256]; + +#endif /* MLKEM_NATIVE_AARCH64_CONSTS */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_clean.S new file mode 100644 index 0000000000..623a82ae9c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_clean.S @@ -0,0 +1,364 @@ +/// Copyright (c) 2024 The mlkem-native project authors +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. 
+/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. +.macro mulmodq dst, src, const, idx0, idx1 + // Signed barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. 
+ sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro gs_butterfly a, b, root, idx0, idx1 + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmodq \b, tmp, \root, \idx0, \idx1 +.endm + +.macro gs_butterfly_v a, b, root, root_twisted + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmod \b, tmp, \root, \root_twisted +.endm + +.macro mul_ninv dst0, dst1, dst2, dst3, src0, src1, src2, src3 + mulmod \dst0, \src0, ninv, ninv_tw + mulmod \dst1, \src1, ninv, ninv_tw + mulmod \dst2, \src2, ninv, ninv_tw + mulmod \dst3, \src3, ninv, ninv_tw +.endm + +.macro barrett_reduce a + sqdmulh t0.8h, \a\().8h, consts.h[1] + srshr t0.8h, t0.8h, #11 + mls \a\().8h, t0.8h, consts.h[0] +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 + ldr q_root1, [r01234_ptr, #-16] +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro transpose_single data_out, data_in + trn1 \data_out\()0.4s, \data_in\()0.4s, \data_in\()1.4s + trn2 \data_out\()1.4s, \data_in\()0.4s, \data_in\()1.4s + trn1 \data_out\()2.4s, 
\data_in\()2.4s, \data_in\()3.4s + trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +// For comparability reasons, the output range for the coefficients of this +// invNTT code is supposed to match the implementation from PQClean on commit +// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients +// are NOT canonically reduced. The ordering of the coefficients is canonical, +// also matching PQClean. + +.text + + .global MLKEM_ASM_NAMESPACE(intt_asm_clean) + + in .req x0 + r01234_ptr .req x1 + r56_ptr .req x2 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + consts .req v7 + q_consts .req q7 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + ninv .req v29 + q_ninv .req q29 + ninv_tw .req v30 + q_ninv_tw .req q30 + +/* Literal pool */ +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_consts: .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 +c_ninv: dup8h 512 +c_ninv_tw: 
dup8h 5040 + +MLKEM_ASM_NAMESPACE(intt_asm_clean): + push_stack + + ldr q_consts, c_consts + ldr q_ninv, c_ninv + ldr q_ninv_tw, c_ninv_tw + + mov inp, in + mov count, #8 + +scale_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + mul_ninv data0, data1, data2, data3, data0, data1, data2, data3 + // Bounds: Absolute value < q + + str q_data0, [inp], #64 + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, scale_start + + mov inp, in + mov count, #8 + + .p2align 2 +layer3456_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + transpose4 data // manual ld4 + + load_next_roots_56 + + // Layer 7 + gs_butterfly_v data0, data1, root1, root1_tw + gs_butterfly_v data2, data3, root2, root2_tw + // Bounds: + // data0, data2: < 2q + // data1, data3: < q + + // Layer 6 + gs_butterfly_v data0, data2, root0, root0_tw + gs_butterfly_v data1, data3, root0, root0_tw + // Bounds: + // data0: < 4q + // data1: < 2q + // data2, data3: < q + + transpose4 data + + load_next_roots_34 + + // Layer 5 + gs_butterfly data0, data1, root0, 2, 3 + gs_butterfly data2, data3, root0, 4, 5 + // Max bound: 8q + + // Not all of those reductions are needed, but the bounds tracking + // is easier if we uniformly reduce at this point. 
+ barrett_reduce data0 + barrett_reduce data2 + barrett_reduce data1 + barrett_reduce data3 + + // Bounds: q/2 + + // Layer 4 + gs_butterfly data0, data2, root0, 0, 1 + gs_butterfly data1, data3, root0, 0, 1 + // Bounds: < q + + str q_data0, [inp], #(64) + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, layer3456_start + + // --------------------------------------------------------------------- + + mov count, #4 + load_roots_012 + + .p2align 2 + +layer012_start: + + ldr q_data0, [in, #0] + ldr q_data1, [in, #(1*(512/8))] + ldr q_data2, [in, #(2*(512/8))] + ldr q_data3, [in, #(3*(512/8))] + ldr q_data4, [in, #(4*(512/8))] + ldr q_data5, [in, #(5*(512/8))] + ldr q_data6, [in, #(6*(512/8))] + ldr q_data7, [in, #(7*(512/8))] + + gs_butterfly data0, data1, root0, 6, 7 + gs_butterfly data2, data3, root1, 0, 1 + gs_butterfly data4, data5, root1, 2, 3 + gs_butterfly data6, data7, root1, 4, 5 + + gs_butterfly data0, data2, root0, 2, 3 + gs_butterfly data1, data3, root0, 2, 3 + gs_butterfly data4, data6, root0, 4, 5 + gs_butterfly data5, data7, root0, 4, 5 + + gs_butterfly data0, data4, root0, 0, 1 + gs_butterfly data1, data5, root0, 0, 1 + gs_butterfly data2, data6, root0, 0, 1 + gs_butterfly data3, data7, root0, 0, 1 + + // Bounds: < 8q + + str q_data4, [in, #(4*(512/8))] + str q_data5, [in, #(5*(512/8))] + str q_data6, [in, #(6*(512/8))] + str q_data7, [in, #(7*(512/8))] + + str q_data0, [in], #(16) + str q_data1, [in, #(-16 + 1*(512/8))] + str q_data2, [in, #(-16 + 2*(512/8))] + str q_data3, [in, #(-16 + 3*(512/8))] + + subs count, count, #1 + cbnz count, layer012_start + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_opt.S new file mode 100644 index 0000000000..e332efef8f --- /dev/null +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/intt_opt.S @@ -0,0 +1,1020 @@ +/// Copyright (c) 2024 The mlkem-native project authors +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. +.macro mulmodq dst, src, const, idx0, idx1 + // Signed barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. 
+ sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, t2.8h, consts.h[0] +.endm + +.macro gs_butterfly a, b, root, idx0, idx1 + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmodq \b, tmp, \root, \idx0, \idx1 +.endm + +.macro gs_butterfly_v a, b, root, root_twisted + sub tmp.8h, \a\().8h, \b\().8h + add \a\().8h, \a\().8h, \b\().8h + mulmod \b, tmp, \root, \root_twisted +.endm + +.macro mul_ninv dst0, dst1, dst2, dst3, src0, src1, src2, src3 + mulmod \dst0, \src0, ninv, ninv_tw + mulmod \dst1, \src1, ninv, ninv_tw + mulmod \dst2, \src2, ninv, ninv_tw + mulmod \dst3, \src3, ninv, ninv_tw +.endm + +.macro barrett_reduce a + sqdmulh t0.8h, \a\().8h, consts.h[1] + srshr t0.8h, t0.8h, #11 + mls \a\().8h, t0.8h, consts.h[0] +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 + ldr q_root1, [r01234_ptr, #-16] +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s + trn2 t1.4s, \data\()0.4s, \data\()1.4s + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro transpose_single data_out, data_in + trn1 \data_out\()0.4s, \data_in\()0.4s, \data_in\()1.4s + trn2 \data_out\()1.4s, \data_in\()0.4s, \data_in\()1.4s + trn1 \data_out\()2.4s, 
\data_in\()2.4s, \data_in\()3.4s + trn2 \data_out\()3.4s, \data_in\()2.4s, \data_in\()3.4s +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + +// For comparability reasons, the output range for the coefficients of this +// invNTT code is supposed to match the implementation from PQClean on commit +// ee71d2c823982bfcf54686f3cf1d666f396dc9aa. After the invNTT, the coefficients +// are NOT canonically reduced. The ordering of the coefficients is canonical, +// also matching PQClean. + +.text + + .global MLKEM_ASM_NAMESPACE(intt_asm_opt) + + in .req x0 + r01234_ptr .req x1 + r56_ptr .req x2 + + inp .req x3 + count .req x4 + xtmp .req x5 + + data0 .req v8 + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + consts .req v7 + q_consts .req q7 + + q_root0 .req q0 + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + tmp .req v24 + t0 .req v25 + t1 .req v26 + t2 .req v27 + t3 .req v28 + + ninv .req v29 + q_ninv .req q29 + ninv_tw .req v30 + q_ninv_tw .req q30 + +/* Literal pool */ +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_consts: .short 3329 + .short 20159 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 +c_ninv: dup8h 512 +c_ninv_tw: 
dup8h 5040 + +MLKEM_ASM_NAMESPACE(intt_asm_opt): + push_stack + + ldr q_consts, c_consts + ldr q_ninv, c_ninv + ldr q_ninv_tw, c_ninv_tw + + mov inp, in + mov count, #8 + +scale_start: + + ldr q_data0, [inp, #(16*0)] + ldr q_data1, [inp, #(16*1)] + ldr q_data2, [inp, #(16*2)] + ldr q_data3, [inp, #(16*3)] + + mul_ninv data0, data1, data2, data3, data0, data1, data2, data3 + // Bounds: Absolute value < q + + str q_data0, [inp], #64 + str q_data1, [inp, #(-64 + 16*1)] + str q_data2, [inp, #(-64 + 16*2)] + str q_data3, [inp, #(-64 + 16*3)] + + subs count, count, #1 + cbnz count, scale_start + + mov inp, in + mov count, #8 + + .p2align 2 + // Instructions: 11 + // Expected cycles: 20 + // Expected IPC: 0.55 + // + // Cycle bound: 20.0 + // IPC bound: 0.55 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q26, [x3, #0] // *............................. + ldr q8, [x3, #16] // ..*........................... + ldr q24, [x3, #32] // ....*......................... + ldr q16, [x3, #48] // ......*....................... + ldr q9, [x2], #(6*16) // ........*..................... + trn1 v0.4S, v24.4S, v16.4S // ..........*................... + ldr q6, [x2, #-80] // ...........*.................. + ldr q3, [x2, #-64] // .............*................ + ldr q15, [x2, #-48] // ...............*.............. + ldr q4, [x2, #-32] // .................*............ + ldr q28, [x2, #-16] // ...................*.......... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q26, [x3, #0] // *.............................. + // ldr q8, [x3, #16] // ..*............................ + // ldr q24, [x3, #32] // ....*.......................... + // ldr q16, [x3, #48] // ......*........................ + // trn1 v0.4S, v24.4S, v16.4S // ..........*.................... + // ldr q9, [x2], #(6*16) // ........*...................... 
+ // ldr q6, [x2, #-80] // ...........*................... + // ldr q3, [x2, #-64] // .............*................. + // ldr q15, [x2, #-48] // ...............*............... + // ldr q4, [x2, #-32] // .................*............. + // ldr q28, [x2, #-16] // ...................*........... + + sub count, count, #1 +layer3456_start: + // Instructions: 83 + // Expected cycles: 94 + // Expected IPC: 0.88 + // + // Cycle bound: 94.0 + // IPC bound: 0.88 + // + // Wall time: 3.34s + // User time: 3.34s + // + // ------------------------------------- cycle (expected) --------------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------------------ + trn1 v12.4S, v26.4S, v8.4S // *............................................................................................. + trn2 v26.4S, v26.4S, v8.4S // .*............................................................................................ + trn2 v8.4S, v24.4S, v16.4S // ..*........................................................................................... + trn2 v11.2D, v12.2D, v0.2D // ...*.......................................................................................... + trn1 v12.2D, v12.2D, v0.2D // ....*......................................................................................... + trn2 v16.2D, v26.2D, v8.2D // .....*........................................................................................ + trn1 v26.2D, v26.2D, v8.2D // ......*....................................................................................... + sub v8.8H, v11.8H, v16.8H // .......*...................................................................................... + add v11.8H, v11.8H, v16.8H // ........*..................................................................................... + sub v16.8H, v12.8H, v26.8H // .........*.................................................................................... 
+ add v12.8H, v12.8H, v26.8H // ..........*................................................................................... + sqrdmulh v26.8H, v8.8H, v28.8H // ...........*.................................................................................. + sqrdmulh v15.8H, v16.8H, v15.8H // ............*................................................................................. + mul v16.8H, v16.8H, v3.8H // .............*................................................................................ + mul v8.8H, v8.8H, v4.8H // ..............*............................................................................... + sub v0.8H, v12.8H, v11.8H // ...............*.............................................................................. + add v12.8H, v12.8H, v11.8H // ................*............................................................................. + mls v16.8H, v15.8H, v7.H[0] // .................*............................................................................ + mls v8.8H, v26.8H, v7.H[0] // ..................*........................................................................... + sqrdmulh v26.8H, v0.8H, v6.8H // ...................*.......................................................................... + mul v11.8H, v0.8H, v9.8H // ....................*......................................................................... + ldr q15, [x1], #16 // .....................*........................................................................ + sub v0.8H, v16.8H, v8.8H // .......................*...................................................................... + mls v11.8H, v26.8H, v7.H[0] // ........................*..................................................................... + add v26.8H, v16.8H, v8.8H // .........................*.................................................................... 
+ sqrdmulh v8.8H, v0.8H, v6.8H // ..........................*................................................................... + mul v16.8H, v0.8H, v9.8H // ...........................*.................................................................. + trn1 v0.4S, v12.4S, v26.4S // ............................*................................................................. + trn2 v12.4S, v12.4S, v26.4S // .............................*................................................................ + ldr q26, [x3, #64] // ..............................e............................................................... + mls v16.8H, v8.8H, v7.H[0] // ................................*............................................................. + ldr q8, [x3, #80] // .................................e............................................................ + ldr q24, [x3, #96] // ...................................e.......................................................... + trn1 v9.4S, v11.4S, v16.4S // .....................................*........................................................ + trn2 v11.4S, v11.4S, v16.4S // ......................................*....................................................... + ldr q16, [x3, #112] // .......................................e...................................................... + trn2 v6.2D, v0.2D, v9.2D // .........................................*.................................................... + trn2 v3.2D, v12.2D, v11.2D // ..........................................*................................................... + trn1 v0.2D, v0.2D, v9.2D // ...........................................*.................................................. + trn1 v12.2D, v12.2D, v11.2D // ............................................*................................................. + sub v11.8H, v6.8H, v3.8H // .............................................*................................................ 
+ sub v9.8H, v0.8H, v12.8H // ..............................................*............................................... + add v12.8H, v0.8H, v12.8H // ...............................................*.............................................. + sqrdmulh v0.8H, v11.8H, v15.H[5] // ................................................*............................................. + sqrdmulh v4.8H, v9.8H, v15.H[3] // .................................................*............................................ + mul v9.8H, v9.8H, v15.H[2] // ..................................................*........................................... + mul v11.8H, v11.8H, v15.H[4] // ...................................................*.......................................... + add v6.8H, v6.8H, v3.8H // ....................................................*......................................... + sqdmulh v3.8H, v12.8H, v7.H[1] // .....................................................*........................................ + mls v9.8H, v4.8H, v7.H[0] // ......................................................*....................................... + mls v11.8H, v0.8H, v7.H[0] // .......................................................*...................................... + sqdmulh v0.8H, v6.8H, v7.H[1] // ........................................................*..................................... + srshr v3.8H, v3.8H, #11 // .........................................................*.................................... + sqdmulh v4.8H, v9.8H, v7.H[1] // ..........................................................*................................... + sqdmulh v28.8H, v11.8H, v7.H[1] // ...........................................................*.................................. + mls v12.8H, v3.8H, v7.H[0] // ............................................................*................................. 
+ srshr v0.8H, v0.8H, #11 // .............................................................*................................ + srshr v3.8H, v4.8H, #11 // ..............................................................*............................... + srshr v4.8H, v28.8H, #11 // ...............................................................*.............................. + mls v6.8H, v0.8H, v7.H[0] // ................................................................*............................. + mls v9.8H, v3.8H, v7.H[0] // .................................................................*............................ + mls v11.8H, v4.8H, v7.H[0] // ..................................................................*........................... + trn1 v0.4S, v24.4S, v16.4S // ...................................................................e.......................... + sub v3.8H, v12.8H, v6.8H // ....................................................................*......................... + add v12.8H, v12.8H, v6.8H // .....................................................................*........................ + sub v6.8H, v9.8H, v11.8H // ......................................................................*....................... + sqrdmulh v4.8H, v3.8H, v15.H[1] // .......................................................................*...................... + mul v3.8H, v3.8H, v15.H[0] // ........................................................................*..................... + sqrdmulh v28.8H, v6.8H, v15.H[1] // .........................................................................*.................... + mul v15.8H, v6.8H, v15.H[0] // ..........................................................................*................... + add v11.8H, v9.8H, v11.8H // ...........................................................................*.................. 
+ mls v3.8H, v4.8H, v7.H[0] // ............................................................................*................. + str q12, [x3], #(64) // .............................................................................*................ + mls v15.8H, v28.8H, v7.H[0] // ..............................................................................*............... + str q11, [x3, #-48] // ...............................................................................*.............. + ldr q9, [x2], #(6*16) // ................................................................................e............. + str q3, [x3, #-32] // ..................................................................................*........... + ldr q6, [x2, #-80] // ...................................................................................e.......... + str q15, [x3, #-16] // .....................................................................................*........ + ldr q3, [x2, #-64] // ......................................................................................e....... + ldr q15, [x2, #-48] // ........................................................................................e..... + ldr q4, [x2, #-32] // ..........................................................................................e... + ldr q28, [x2, #-16] // ............................................................................................e. + + // ----------------------------------------------------------------- cycle (expected) ------------------------------------------------------------------> + // 0 25 50 75 100 125 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------ + // ldr q8, [x3, #(16*0)] // e...............................................................'.............................~....................................................... 
+ // ldr q9, [x3, #(16*1)] // ...e............................................................'................................~.................................................... + // ldr q10, [x3, #(16*2)] // .....e..........................................................'..................................~.................................................. + // ldr q11, [x3, #(16*3)] // .........e......................................................'......................................~.............................................. + // trn1 v25.4s, v8.4s, v9.4s // ................................................................*..................................................................................... + // trn2 v26.4s, v8.4s, v9.4s // ................................................................'*.................................................................................... + // trn1 v27.4s, v10.4s, v11.4s // .....................................e..........................'..................................................................~.................. + // trn2 v28.4s, v10.4s, v11.4s // ................................................................'.*................................................................................... + // trn2 v10.2d, v25.2d, v27.2d // ................................................................'..*.................................................................................. + // trn2 v11.2d, v26.2d, v28.2d // ................................................................'....*................................................................................ + // trn1 v8.2d, v25.2d, v27.2d // ................................................................'...*................................................................................. 
+ // trn1 v9.2d, v26.2d, v28.2d // ................................................................'.....*............................................................................... + // ldr q0, [x2], #(6*16) // ..................................................e.............'...............................................................................~..... + // ldr q4, [x2, #(-6*16 + 1*16)] // .....................................................e..........'..................................................................................~.. + // ldr q1, [x2, #(-6*16 + 2*16)] // ........................................................e.......'..................................................................................... + // ldr q5, [x2, #(-6*16 + 3*16)] // ..........................................................e.....'..................................................................................... + // ldr q2, [x2, #(-6*16 + 4*16)] // ............................................................e...'..................................................................................... + // ldr q6, [x2, #(-6*16 + 5*16)] // ..............................................................e.'..................................................................................... + // sub v24.8h, v8.8h, v9.8h // ................................................................'........*............................................................................ + // add v8.8h, v8.8h, v9.8h // ................................................................'.........*........................................................................... + // sqrdmulh v27.8h, v24.8h, v5.8h // ................................................................'...........*......................................................................... 
+ // mul v9.8h, v24.8h, v1.8h // ................................................................'............*........................................................................ + // mls v9.8h, v27.8h, v7.h[0] // ................................................................'................*.................................................................... + // sub v24.8h, v10.8h, v11.8h // ................................................................'......*.............................................................................. + // add v10.8h, v10.8h, v11.8h // ................................................................'.......*............................................................................. + // sqrdmulh v27.8h, v24.8h, v6.8h // ................................................................'..........*.......................................................................... + // mul v11.8h, v24.8h, v2.8h // ................................................................'.............*....................................................................... + // mls v11.8h, v27.8h, v7.h[0] // ................................................................'.................*................................................................... + // sub v24.8h, v8.8h, v10.8h // ................................................................'..............*...................................................................... + // add v8.8h, v8.8h, v10.8h // ................................................................'...............*..................................................................... + // sqrdmulh v27.8h, v24.8h, v4.8h // ................................................................'..................*.................................................................. 
+ // mul v10.8h, v24.8h, v0.8h // ................................................................'...................*................................................................. + // mls v10.8h, v27.8h, v7.h[0] // ................................................................'.......................*............................................................. + // sub v24.8h, v9.8h, v11.8h // ................................................................'......................*.............................................................. + // add v9.8h, v9.8h, v11.8h // ................................................................'........................*............................................................ + // sqrdmulh v27.8h, v24.8h, v4.8h // ................................................................'.........................*........................................................... + // mul v11.8h, v24.8h, v0.8h // ................................................................'..........................*.......................................................... + // mls v11.8h, v27.8h, v7.h[0] // ..~.............................................................'...............................*..................................................... + // trn1 v25.4s, v8.4s, v9.4s // ................................................................'...........................*......................................................... + // trn2 v26.4s, v8.4s, v9.4s // ................................................................'............................*........................................................ + // trn1 v27.4s, v10.4s, v11.4s // .......~........................................................'....................................*................................................ 
+ // trn2 v28.4s, v10.4s, v11.4s // ........~.......................................................'.....................................*............................................... + // trn2 v10.2d, v25.2d, v27.2d // ...........~....................................................'........................................*............................................ + // trn2 v11.2d, v26.2d, v28.2d // ............~...................................................'.........................................*........................................... + // trn1 v8.2d, v25.2d, v27.2d // .............~..................................................'..........................................*.......................................... + // trn1 v9.2d, v26.2d, v28.2d // ..............~.................................................'...........................................*......................................... + // ldr q0, [x1], #16 // ................................................................'....................*................................................................ + // sub v24.8h, v8.8h, v9.8h // ................~...............................................'.............................................*....................................... + // add v8.8h, v8.8h, v9.8h // .................~..............................................'..............................................*...................................... + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...................~............................................'................................................*.................................... + // mul v9.8h, v24.8h, v0.h[2] // ....................~...........................................'.................................................*................................... 
+ // mls v9.8h, v27.8h, v7.h[0] // ........................~.......................................'.....................................................*............................... + // sub v24.8h, v10.8h, v11.8h // ...............~................................................'............................................*........................................ + // add v10.8h, v10.8h, v11.8h // ......................~.........................................'...................................................*................................. + // sqrdmulh v27.8h, v24.8h, v0.h[5] // ..................~.............................................'...............................................*..................................... + // mul v11.8h, v24.8h, v0.h[4] // .....................~..........................................'..................................................*.................................. + // mls v11.8h, v27.8h, v7.h[0] // .........................~......................................'......................................................*.............................. + // sqdmulh v25.8h, v8.8h, v7.h[1] // .......................~........................................'....................................................*................................ + // srshr v25.8h, v25.8h, #11 // ...........................~....................................'........................................................*............................ + // mls v8.8h, v25.8h, v7.h[0] // ..............................~.................................'...........................................................*......................... + // sqdmulh v25.8h, v10.8h, v7.h[1] // ..........................~.....................................'.......................................................*............................. 
+ // srshr v25.8h, v25.8h, #11 // ...............................~................................'............................................................*........................ + // mls v10.8h, v25.8h, v7.h[0] // ..................................~.............................'...............................................................*..................... + // sqdmulh v25.8h, v9.8h, v7.h[1] // ............................~...................................'.........................................................*........................... + // srshr v25.8h, v25.8h, #11 // ................................~...............................'.............................................................*....................... + // mls v9.8h, v25.8h, v7.h[0] // ...................................~............................'................................................................*.................... + // sqdmulh v25.8h, v11.8h, v7.h[1] // .............................~..................................'..........................................................*.......................... + // srshr v25.8h, v25.8h, #11 // .................................~..............................'..............................................................*...................... + // mls v11.8h, v25.8h, v7.h[0] // ....................................~...........................'.................................................................*................... + // sub v24.8h, v8.8h, v10.8h // ......................................~.........................'...................................................................*................. + // add v8.8h, v8.8h, v10.8h // .......................................~........................'....................................................................*................ 
+ // sqrdmulh v27.8h, v24.8h, v0.h[1] // .........................................~......................'......................................................................*.............. + // mul v10.8h, v24.8h, v0.h[0] // ..........................................~.....................'.......................................................................*............. + // mls v10.8h, v27.8h, v7.h[0] // ..............................................~.................'...........................................................................*......... + // sub v24.8h, v9.8h, v11.8h // ........................................~.......................'.....................................................................*............... + // add v9.8h, v9.8h, v11.8h // .............................................~..................'..........................................................................*.......... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................................~....................'........................................................................*............ + // mul v11.8h, v24.8h, v0.h[0] // ............................................~...................'.........................................................................*........... + // mls v11.8h, v27.8h, v7.h[0] // ................................................~...............'.............................................................................*....... + // str q8, [x3], #(64) // ...............................................~................'............................................................................*........ + // str q9, [x3, #(-64 + 16*1)] // .................................................~..............'..............................................................................*...... 
+ // str q10, [x3, #(-64 + 16*2)] // ....................................................~...........'.................................................................................*... + // str q11, [x3, #(-64 + 16*3)] // .......................................................~........'....................................................................................* + + sub count, count, #1 + cbnz count, layer3456_start + // Instructions: 72 + // Expected cycles: 79 + // Expected IPC: 0.91 + // + // Cycle bound: 79.0 + // IPC bound: 0.91 + // + // Wall time: 9.28s + // User time: 9.28s + // + // ------------------------------ cycle (expected) ------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|--- + trn1 v11.4S, v26.4S, v8.4S // *.............................................................................. + trn2 v24.4S, v24.4S, v16.4S // .*............................................................................. + trn2 v26.4S, v26.4S, v8.4S // ..*............................................................................ + trn1 v18.2D, v11.2D, v0.2D // ...*........................................................................... + trn2 v11.2D, v11.2D, v0.2D // ....*.......................................................................... + trn2 v12.2D, v26.2D, v24.2D // .....*......................................................................... + trn1 v8.2D, v26.2D, v24.2D // ......*........................................................................ + sub v26.8H, v11.8H, v12.8H // .......*....................................................................... + sub v13.8H, v18.8H, v8.8H // ........*...................................................................... + add v24.8H, v18.8H, v8.8H // .........*..................................................................... 
+ mul v16.8H, v26.8H, v4.8H // ..........*.................................................................... + sqrdmulh v17.8H, v13.8H, v15.8H // ...........*................................................................... + mul v3.8H, v13.8H, v3.8H // ............*.................................................................. + sqrdmulh v26.8H, v26.8H, v28.8H // .............*................................................................. + add v10.8H, v11.8H, v12.8H // ..............*................................................................ + mls v3.8H, v17.8H, v7.H[0] // ................*.............................................................. + mls v16.8H, v26.8H, v7.H[0] // .................*............................................................. + sub v26.8H, v24.8H, v10.8H // ..................*............................................................ + ldr q4, [x1], #16 // ...................*........................................................... + sub v12.8H, v3.8H, v16.8H // .....................*......................................................... + sqrdmulh v15.8H, v26.8H, v6.8H // ......................*........................................................ + mul v11.8H, v26.8H, v9.8H // .......................*....................................................... + mul v8.8H, v12.8H, v9.8H // ........................*...................................................... + sqrdmulh v12.8H, v12.8H, v6.8H // .........................*..................................................... + add v0.8H, v24.8H, v10.8H // ..........................*.................................................... + mls v11.8H, v15.8H, v7.H[0] // ...........................*................................................... + add v6.8H, v3.8H, v16.8H // ............................*.................................................. 
+ mls v8.8H, v12.8H, v7.H[0] // .............................*................................................. + trn2 v26.4S, v0.4S, v6.4S // ...............................*............................................... + trn2 v12.4S, v11.4S, v8.4S // .................................*............................................. + trn1 v3.4S, v11.4S, v8.4S // ..................................*............................................ + trn1 v17.4S, v0.4S, v6.4S // ...................................*........................................... + trn1 v8.2D, v26.2D, v12.2D // ....................................*.......................................... + trn2 v13.2D, v26.2D, v12.2D // .....................................*......................................... + trn1 v11.2D, v17.2D, v3.2D // ......................................*........................................ + trn2 v15.2D, v17.2D, v3.2D // .......................................*....................................... + sub v12.8H, v11.8H, v8.8H // ........................................*...................................... + add v16.8H, v15.8H, v13.8H // .........................................*..................................... + sub v26.8H, v15.8H, v13.8H // ..........................................*.................................... + mul v0.8H, v12.8H, v4.H[2] // ...........................................*................................... + sqrdmulh v9.8H, v12.8H, v4.H[3] // ............................................*.................................. + mul v13.8H, v26.8H, v4.H[4] // .............................................*................................. + sqrdmulh v26.8H, v26.8H, v4.H[5] // ..............................................*................................ + add v24.8H, v11.8H, v8.8H // ...............................................*............................... 
+ mls v0.8H, v9.8H, v7.H[0] // ................................................*.............................. + sqdmulh v12.8H, v16.8H, v7.H[1] // .................................................*............................. + mls v13.8H, v26.8H, v7.H[0] // ..................................................*............................ + sqdmulh v11.8H, v24.8H, v7.H[1] // ...................................................*........................... + sqdmulh v8.8H, v0.8H, v7.H[1] // ....................................................*.......................... + srshr v12.8H, v12.8H, #11 // .....................................................*......................... + sqdmulh v26.8H, v13.8H, v7.H[1] // ......................................................*........................ + srshr v11.8H, v11.8H, #11 // .......................................................*....................... + mls v16.8H, v12.8H, v7.H[0] // ........................................................*...................... + srshr v8.8H, v8.8H, #11 // .........................................................*..................... + srshr v26.8H, v26.8H, #11 // ..........................................................*.................... + mls v24.8H, v11.8H, v7.H[0] // ...........................................................*................... + mls v0.8H, v8.8H, v7.H[0] // ............................................................*.................. + mls v13.8H, v26.8H, v7.H[0] // .............................................................*................. + sub v26.8H, v24.8H, v16.8H // ...............................................................*............... + add v15.8H, v24.8H, v16.8H // ................................................................*.............. + sub v12.8H, v0.8H, v13.8H // .................................................................*............. 
+ mul v11.8H, v26.8H, v4.H[0] // ..................................................................*............ + sqrdmulh v16.8H, v26.8H, v4.H[1] // ...................................................................*........... + mul v26.8H, v12.8H, v4.H[0] // ....................................................................*.......... + sqrdmulh v8.8H, v12.8H, v4.H[1] // .....................................................................*......... + add v12.8H, v0.8H, v13.8H // ......................................................................*........ + mls v11.8H, v16.8H, v7.H[0] // .......................................................................*....... + str q15, [x3], #(64) // ........................................................................*...... + mls v26.8H, v8.8H, v7.H[0] // .........................................................................*..... + str q12, [x3, #-48] // ..........................................................................*.... + str q11, [x3, #-32] // ............................................................................*.. + str q26, [x3, #-16] // ..............................................................................* + + // ------------------------------ cycle (expected) ------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|--- + // trn1 v12.4S, v26.4S, v8.4S // *.............................................................................. + // trn2 v26.4S, v26.4S, v8.4S // ..*............................................................................ + // trn2 v8.4S, v24.4S, v16.4S // .*............................................................................. + // trn2 v11.2D, v12.2D, v0.2D // ....*.......................................................................... + // trn1 v12.2D, v12.2D, v0.2D // ...*........................................................................... 
+ // trn2 v16.2D, v26.2D, v8.2D // .....*......................................................................... + // trn1 v26.2D, v26.2D, v8.2D // ......*........................................................................ + // sub v8.8H, v11.8H, v16.8H // .......*....................................................................... + // add v11.8H, v11.8H, v16.8H // ..............*................................................................ + // sub v16.8H, v12.8H, v26.8H // ........*...................................................................... + // add v12.8H, v12.8H, v26.8H // .........*..................................................................... + // sqrdmulh v26.8H, v8.8H, v28.8H // .............*................................................................. + // sqrdmulh v15.8H, v16.8H, v15.8H // ...........*................................................................... + // mul v16.8H, v16.8H, v3.8H // ............*.................................................................. + // mul v8.8H, v8.8H, v4.8H // ..........*.................................................................... + // sub v0.8H, v12.8H, v11.8H // ..................*............................................................ + // add v12.8H, v12.8H, v11.8H // ..........................*.................................................... + // mls v16.8H, v15.8H, v7.H[0] // ................*.............................................................. + // mls v8.8H, v26.8H, v7.H[0] // .................*............................................................. + // sqrdmulh v26.8H, v0.8H, v6.8H // ......................*........................................................ + // mul v11.8H, v0.8H, v9.8H // .......................*....................................................... + // ldr q15, [x1], #16 // ...................*........................................................... 
+ // sub v0.8H, v16.8H, v8.8H // .....................*......................................................... + // mls v11.8H, v26.8H, v7.H[0] // ...........................*................................................... + // add v26.8H, v16.8H, v8.8H // ............................*.................................................. + // sqrdmulh v8.8H, v0.8H, v6.8H // .........................*..................................................... + // mul v16.8H, v0.8H, v9.8H // ........................*...................................................... + // trn1 v0.4S, v12.4S, v26.4S // ...................................*........................................... + // trn2 v12.4S, v12.4S, v26.4S // ...............................*............................................... + // mls v16.8H, v8.8H, v7.H[0] // .............................*................................................. + // trn1 v9.4S, v11.4S, v16.4S // ..................................*............................................ + // trn2 v11.4S, v11.4S, v16.4S // .................................*............................................. + // trn2 v6.2D, v0.2D, v9.2D // .......................................*....................................... + // trn2 v3.2D, v12.2D, v11.2D // .....................................*......................................... + // trn1 v0.2D, v0.2D, v9.2D // ......................................*........................................ + // trn1 v12.2D, v12.2D, v11.2D // ....................................*.......................................... + // sub v11.8H, v6.8H, v3.8H // ..........................................*.................................... + // sub v9.8H, v0.8H, v12.8H // ........................................*...................................... + // add v12.8H, v0.8H, v12.8H // ...............................................*............................... 
+ // sqrdmulh v0.8H, v11.8H, v15.H[5] // ..............................................*................................ + // sqrdmulh v4.8H, v9.8H, v15.H[3] // ............................................*.................................. + // mul v9.8H, v9.8H, v15.H[2] // ...........................................*................................... + // mul v11.8H, v11.8H, v15.H[4] // .............................................*................................. + // add v6.8H, v6.8H, v3.8H // .........................................*..................................... + // sqdmulh v3.8H, v12.8H, v7.H[1] // ...................................................*........................... + // mls v9.8H, v4.8H, v7.H[0] // ................................................*.............................. + // mls v11.8H, v0.8H, v7.H[0] // ..................................................*............................ + // sqdmulh v0.8H, v6.8H, v7.H[1] // .................................................*............................. + // srshr v3.8H, v3.8H, #11 // .......................................................*....................... + // sqdmulh v4.8H, v9.8H, v7.H[1] // ....................................................*.......................... + // sqdmulh v28.8H, v11.8H, v7.H[1] // ......................................................*........................ + // mls v12.8H, v3.8H, v7.H[0] // ...........................................................*................... + // srshr v0.8H, v0.8H, #11 // .....................................................*......................... + // srshr v3.8H, v4.8H, #11 // .........................................................*..................... + // srshr v4.8H, v28.8H, #11 // ..........................................................*.................... + // mls v6.8H, v0.8H, v7.H[0] // ........................................................*...................... 
+ // mls v9.8H, v3.8H, v7.H[0] // ............................................................*.................. + // mls v11.8H, v4.8H, v7.H[0] // .............................................................*................. + // sub v3.8H, v12.8H, v6.8H // ...............................................................*............... + // add v12.8H, v12.8H, v6.8H // ................................................................*.............. + // sub v6.8H, v9.8H, v11.8H // .................................................................*............. + // sqrdmulh v4.8H, v3.8H, v15.H[1] // ...................................................................*........... + // mul v3.8H, v3.8H, v15.H[0] // ..................................................................*............ + // sqrdmulh v28.8H, v6.8H, v15.H[1] // .....................................................................*......... + // mul v15.8H, v6.8H, v15.H[0] // ....................................................................*.......... + // add v11.8H, v9.8H, v11.8H // ......................................................................*........ + // mls v3.8H, v4.8H, v7.H[0] // .......................................................................*....... + // str q12, [x3], #(64) // ........................................................................*...... + // mls v15.8H, v28.8H, v7.H[0] // .........................................................................*..... + // str q11, [x3, #-48] // ..........................................................................*.... + // str q3, [x3, #-32] // ............................................................................*.. 
+ // str q15, [x3, #-16] // ..............................................................................* + + + // --------------------------------------------------------------------- + + mov count, #4 + load_roots_012 + + .p2align 2 + + // Instructions: 12 + // Expected cycles: 19 + // Expected IPC: 0.63 + // + // Cycle bound: 19.0 + // IPC bound: 0.63 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q24, [x0, #128] // *............................. + ldr q16, [x0, #192] // ..*........................... + ldr q9, [x0, #256] // ....*......................... + ldr q6, [x0, #320] // ......*....................... + ldr q3, [x0, #384] // ........*..................... + ldr q4, [x0, #448] // ..........*................... + add v28.8H, v9.8H, v6.8H // ............*................. + add v19.8H, v24.8H, v16.8H // .............*................ + add v13.8H, v3.8H, v4.8H // ..............*............... + ldr q11, [x0, #0] // ...............*.............. + add v23.8H, v28.8H, v13.8H // .................*............ + ldr q15, [x0, #64] // ..................*........... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q11, [x0, #0] // ...............*............... + // ldr q15, [x0, #64] // ..................*............ + // ldr q24, [x0, #128] // *.............................. + // ldr q16, [x0, #192] // ..*............................ + // ldr q9, [x0, #256] // ....*.......................... + // ldr q6, [x0, #320] // ......*........................ + // ldr q3, [x0, #384] // ........*...................... + // ldr q4, [x0, #448] // ..........*.................... + // add v28.8H, v9.8H, v6.8H // ............*.................. + // add v13.8H, v3.8H, v4.8H // ..............*................ + // add v19.8H, v24.8H, v16.8H // .............*................. 
+ // add v23.8H, v28.8H, v13.8H // .................*............. + + sub count, count, #1 +layer012_start: + // Instructions: 76 + // Expected cycles: 84 + // Expected IPC: 0.90 + // + // Cycle bound: 84.0 + // IPC bound: 0.90 + // + // Wall time: 2.81s + // User time: 2.81s + // + // -------------------------------- cycle (expected) ---------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|-------- + sub v12.8H, v11.8H, v15.8H // *................................................................................... + add v26.8H, v11.8H, v15.8H // .*.................................................................................. + sub v8.8H, v24.8H, v16.8H // ..*................................................................................. + sqrdmulh v11.8H, v12.8H, v0.H[7] // ...*................................................................................ + mul v12.8H, v12.8H, v0.H[6] // ....*............................................................................... + sub v16.8H, v26.8H, v19.8H // .....*.............................................................................. + add v26.8H, v26.8H, v19.8H // ......*............................................................................. + sqrdmulh v15.8H, v8.8H, v1.H[1] // .......*............................................................................ + mul v8.8H, v8.8H, v1.H[0] // ........*........................................................................... + mls v12.8H, v11.8H, v7.H[0] // .........*.......................................................................... + sub v11.8H, v9.8H, v6.8H // ..........*......................................................................... + sqrdmulh v24.8H, v16.8H, v0.H[3] // ...........*........................................................................ 
+ mul v16.8H, v16.8H, v0.H[2] // ............*....................................................................... + sub v9.8H, v26.8H, v23.8H // .............*...................................................................... + add v26.8H, v26.8H, v23.8H // ..............*..................................................................... + mls v8.8H, v15.8H, v7.H[0] // ...............*.................................................................... + sqrdmulh v15.8H, v11.8H, v1.H[3] // ................*................................................................... + mul v11.8H, v11.8H, v1.H[2] // .................*.................................................................. + sub v6.8H, v3.8H, v4.8H // ..................*................................................................. + sub v3.8H, v12.8H, v8.8H // ...................*................................................................ + add v12.8H, v12.8H, v8.8H // ....................*............................................................... + mls v11.8H, v15.8H, v7.H[0] // .....................*.............................................................. + sqrdmulh v8.8H, v6.8H, v1.H[5] // ......................*............................................................. + mls v16.8H, v24.8H, v7.H[0] // .......................*............................................................ + mul v15.8H, v6.8H, v1.H[4] // ........................*........................................................... + sqrdmulh v24.8H, v3.8H, v0.H[3] // .........................*.......................................................... + mul v6.8H, v3.8H, v0.H[2] // ..........................*......................................................... + sqrdmulh v3.8H, v9.8H, v0.H[1] // ...........................*........................................................ 
+ mul v9.8H, v9.8H, v0.H[0] // ............................*....................................................... + str q26, [x0], #(16) // .............................*...................................................... + mls v15.8H, v8.8H, v7.H[0] // ..............................*..................................................... + mls v6.8H, v24.8H, v7.H[0] // ...............................*.................................................... + sub v26.8H, v28.8H, v13.8H // ................................*................................................... + mls v9.8H, v3.8H, v7.H[0] // .................................*.................................................. + sub v8.8H, v11.8H, v15.8H // ..................................*................................................. + sqrdmulh v24.8H, v26.8H, v0.H[5] // ...................................*................................................ + mul v26.8H, v26.8H, v0.H[4] // ....................................*............................................... + add v11.8H, v11.8H, v15.8H // .....................................*.............................................. + sqrdmulh v15.8H, v8.8H, v0.H[5] // ......................................*............................................. + mul v8.8H, v8.8H, v0.H[4] // .......................................*............................................ + mls v26.8H, v24.8H, v7.H[0] // ........................................*........................................... + sub v24.8H, v12.8H, v11.8H // .........................................*.......................................... + add v12.8H, v12.8H, v11.8H // ..........................................*......................................... + mls v8.8H, v15.8H, v7.H[0] // ...........................................*........................................ + sqrdmulh v11.8H, v24.8H, v0.H[1] // ............................................*....................................... 
+ mul v15.8H, v24.8H, v0.H[0] // .............................................*...................................... + sub v24.8H, v16.8H, v26.8H // ..............................................*..................................... + add v26.8H, v16.8H, v26.8H // ...............................................*.................................... + sub v16.8H, v6.8H, v8.8H // ................................................*................................... + mls v15.8H, v11.8H, v7.H[0] // .................................................*.................................. + sqrdmulh v11.8H, v24.8H, v0.H[1] // ..................................................*................................. + mul v24.8H, v24.8H, v0.H[0] // ...................................................*................................ + add v8.8H, v6.8H, v8.8H // ....................................................*............................... + sqrdmulh v6.8H, v16.8H, v0.H[1] // .....................................................*.............................. + mul v16.8H, v16.8H, v0.H[0] // ......................................................*............................. + mls v24.8H, v11.8H, v7.H[0] // .......................................................*............................ + str q9, [x0, #240] // ........................................................*........................... + ldr q11, [x0, #0] // .........................................................e.......................... + mls v16.8H, v6.8H, v7.H[0] // ...........................................................*........................ + str q15, [x0, #304] // ............................................................*....................... + ldr q15, [x0, #64] // .............................................................e...................... + str q24, [x0, #368] // ...............................................................*.................... 
+ ldr q24, [x0, #128] // ................................................................e................... + str q16, [x0, #432] // ..................................................................*................. + ldr q16, [x0, #192] // ...................................................................e................ + str q12, [x0, #48] // .....................................................................*.............. + ldr q9, [x0, #256] // ......................................................................e............. + ldr q6, [x0, #320] // ........................................................................e........... + ldr q3, [x0, #384] // ..........................................................................e......... + ldr q4, [x0, #448] // ............................................................................e....... + str q26, [x0, #112] // ..............................................................................*..... + add v28.8H, v9.8H, v6.8H // ...............................................................................e.... + add v13.8H, v3.8H, v4.8H // ................................................................................e... + str q8, [x0, #176] // .................................................................................*.. + add v19.8H, v24.8H, v16.8H // ..................................................................................e. + add v23.8H, v28.8H, v13.8H // ...................................................................................e + + // --------------------------------------------- cycle (expected) ---------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|-------- + // ldr q8, [x0, #0] // e..........................'........................................................~........................ 
+ // ldr q9, [x0, #(1*(512/8))] // ....e......................'............................................................~.................... + // ldr q10, [x0, #(2*(512/8))] // .......e...................'...............................................................~................. + // ldr q11, [x0, #(3*(512/8))] // ..........e................'..................................................................~.............. + // ldr q12, [x0, #(4*(512/8))] // .............e.............'.....................................................................~........... + // ldr q13, [x0, #(5*(512/8))] // ...............e...........'.......................................................................~......... + // ldr q14, [x0, #(6*(512/8))] // .................e.........'.........................................................................~....... + // ldr q15, [x0, #(7*(512/8))] // ...................e.......'...........................................................................~..... + // sub v24.8h, v8.8h, v9.8h // ...........................*................................................................................. + // add v8.8h, v8.8h, v9.8h // ...........................'*................................................................................ + // sqrdmulh v27.8h, v24.8h, v0.h[7] // ...........................'..*.............................................................................. + // mul v9.8h, v24.8h, v0.h[6] // ...........................'...*............................................................................. + // mls v9.8h, v27.8h, v7.h[0] // ...........................'........*........................................................................ + // sub v24.8h, v10.8h, v11.8h // ...........................'.*............................................................................... 
+ // add v10.8h, v10.8h, v11.8h // .........................e.'................................................................................. + // sqrdmulh v27.8h, v24.8h, v1.h[1] // ...........................'......*.......................................................................... + // mul v11.8h, v24.8h, v1.h[0] // ...........................'.......*......................................................................... + // mls v11.8h, v27.8h, v7.h[0] // ...........................'..............*.................................................................. + // sub v24.8h, v12.8h, v13.8h // ...........................'.........*....................................................................... + // add v12.8h, v12.8h, v13.8h // ......................e....'..............................................................................~.. + // sqrdmulh v27.8h, v24.8h, v1.h[3] // ...........................'...............*................................................................. + // mul v13.8h, v24.8h, v1.h[2] // ...........................'................*................................................................ + // mls v13.8h, v27.8h, v7.h[0] // ...........................'....................*............................................................ + // sub v24.8h, v14.8h, v15.8h // ...........................'.................*............................................................... + // add v14.8h, v14.8h, v15.8h // .......................e...'...............................................................................~. + // sqrdmulh v27.8h, v24.8h, v1.h[5] // ...........................'.....................*........................................................... + // mul v15.8h, v24.8h, v1.h[4] // ...........................'.......................*......................................................... 
+ // mls v15.8h, v27.8h, v7.h[0] // ...........................'.............................*................................................... + // sub v24.8h, v8.8h, v10.8h // ...........................'....*............................................................................ + // add v8.8h, v8.8h, v10.8h // ...........................'.....*........................................................................... + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...........................'..........*...................................................................... + // mul v10.8h, v24.8h, v0.h[2] // ...........................'...........*..................................................................... + // mls v10.8h, v27.8h, v7.h[0] // ...........................'......................*.......................................................... + // sub v24.8h, v9.8h, v11.8h // ...........................'..................*.............................................................. + // add v9.8h, v9.8h, v11.8h // ...........................'...................*............................................................. + // sqrdmulh v27.8h, v24.8h, v0.h[3] // ...........................'........................*........................................................ + // mul v11.8h, v24.8h, v0.h[2] // ...........................'.........................*....................................................... + // mls v11.8h, v27.8h, v7.h[0] // ...........................'..............................*.................................................. + // sub v24.8h, v12.8h, v14.8h // ...........................'...............................*................................................. + // add v12.8h, v12.8h, v14.8h // ..........................e'................................................................................. 
+ // sqrdmulh v27.8h, v24.8h, v0.h[5] // ...........................'..................................*.............................................. + // mul v14.8h, v24.8h, v0.h[4] // ...........................'...................................*............................................. + // mls v14.8h, v27.8h, v7.h[0] // ...........................'.......................................*......................................... + // sub v24.8h, v13.8h, v15.8h // ...........................'.................................*............................................... + // add v13.8h, v13.8h, v15.8h // ...........................'....................................*............................................ + // sqrdmulh v27.8h, v24.8h, v0.h[5] // ...........................'.....................................*........................................... + // mul v15.8h, v24.8h, v0.h[4] // ...........................'......................................*.......................................... + // mls v15.8h, v27.8h, v7.h[0] // ...........................'..........................................*...................................... + // sub v24.8h, v8.8h, v12.8h // ...........................'............*.................................................................... + // add v8.8h, v8.8h, v12.8h // ...........................'.............*................................................................... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'..........................*...................................................... + // mul v12.8h, v24.8h, v0.h[0] // ...........................'...........................*..................................................... + // mls v12.8h, v27.8h, v7.h[0] // ...........................'................................*................................................ 
+ // sub v24.8h, v9.8h, v13.8h // ...........................'........................................*........................................ + // add v9.8h, v9.8h, v13.8h // ...........................'.........................................*....................................... + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'...........................................*..................................... + // mul v13.8h, v24.8h, v0.h[0] // ...........................'............................................*.................................... + // mls v13.8h, v27.8h, v7.h[0] // ...........................'................................................*................................ + // sub v24.8h, v10.8h, v14.8h // ...........................'.............................................*................................... + // add v10.8h, v10.8h, v14.8h // ...........................'..............................................*.................................. + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'.................................................*............................... + // mul v14.8h, v24.8h, v0.h[0] // ...........................'..................................................*.............................. + // mls v14.8h, v27.8h, v7.h[0] // ...........................'......................................................*.......................... + // sub v24.8h, v11.8h, v15.8h // ...........................'...............................................*................................. + // add v11.8h, v11.8h, v15.8h // ...........................'...................................................*............................. + // sqrdmulh v27.8h, v24.8h, v0.h[1] // ...........................'....................................................*............................ 
+ // mul v15.8h, v24.8h, v0.h[0] // ...........................'.....................................................*........................... + // mls v15.8h, v27.8h, v7.h[0] // ..~........................'..........................................................*...................... + // str q12, [x0, #(4*(512/8))] // ...........................'.......................................................*......................... + // str q13, [x0, #(5*(512/8))] // ...~.......................'...........................................................*..................... + // str q14, [x0, #(6*(512/8))] // ......~....................'..............................................................*.................. + // str q15, [x0, #(7*(512/8))] // .........~.................'.................................................................*............... + // str q8, [x0], #(16) // ...........................'............................*.................................................... + // str q9, [x0, #(-16 + 1*(512/8))] // ............~..............'....................................................................*............ + // str q10, [x0, #(-16 + 2*(512/8))] // .....................~.....'.............................................................................*... + // str q11, [x0, #(-16 + 3*(512/8))] // ........................~..'................................................................................* + + sub count, count, #1 + cbnz count, layer012_start + // Instructions: 64 + // Expected cycles: 66 + // Expected IPC: 0.97 + // + // Cycle bound: 66.0 + // IPC bound: 0.97 + // + // Wall time: 8.33s + // User time: 8.33s + // + // ----------------------- cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|--------------- + add v10.8H, v11.8H, v15.8H // *................................................................. 
+ sub v12.8H, v28.8H, v13.8H // .*................................................................ + sub v11.8H, v11.8H, v15.8H // ..*............................................................... + sub v22.8H, v10.8H, v19.8H // ...*.............................................................. + mul v18.8H, v12.8H, v0.H[4] // ....*............................................................. + sqrdmulh v26.8H, v12.8H, v0.H[5] // .....*............................................................ + sqrdmulh v12.8H, v22.8H, v0.H[3] // ......*........................................................... + mul v13.8H, v22.8H, v0.H[2] // .......*.......................................................... + sub v31.8H, v24.8H, v16.8H // ........*......................................................... + sqrdmulh v22.8H, v11.8H, v0.H[7] // .........*........................................................ + mls v18.8H, v26.8H, v7.H[0] // ..........*....................................................... + mls v13.8H, v12.8H, v7.H[0] // ...........*...................................................... + sqrdmulh v2.8H, v31.8H, v1.H[1] // ............*..................................................... + mul v5.8H, v31.8H, v1.H[0] // .............*.................................................... + mul v15.8H, v11.8H, v0.H[6] // ..............*................................................... + sub v12.8H, v13.8H, v18.8H // ...............*.................................................. + sub v4.8H, v3.8H, v4.8H // ................*................................................. + mls v5.8H, v2.8H, v7.H[0] // .................*................................................ + sqrdmulh v26.8H, v12.8H, v0.H[1] // ..................*............................................... + mul v12.8H, v12.8H, v0.H[0] // ...................*.............................................. 
+ mls v15.8H, v22.8H, v7.H[0] // ....................*............................................. + sqrdmulh v8.8H, v4.8H, v1.H[5] // .....................*............................................ + mul v4.8H, v4.8H, v1.H[4] // ......................*........................................... + mls v12.8H, v26.8H, v7.H[0] // .......................*.......................................... + sub v21.8H, v15.8H, v5.8H // ........................*......................................... + sub v28.8H, v9.8H, v6.8H // .........................*........................................ + mls v4.8H, v8.8H, v7.H[0] // ..........................*....................................... + mul v24.8H, v21.8H, v0.H[2] // ...........................*...................................... + sqrdmulh v8.8H, v21.8H, v0.H[3] // ............................*..................................... + sqrdmulh v6.8H, v28.8H, v1.H[3] // .............................*.................................... + add v19.8H, v10.8H, v19.8H // ..............................*................................... + mul v28.8H, v28.8H, v1.H[2] // ...............................*.................................. + mls v24.8H, v8.8H, v7.H[0] // ................................*................................. + sub v11.8H, v19.8H, v23.8H // .................................*................................ + str q12, [x0, #384] // ..................................*............................... + mls v28.8H, v6.8H, v7.H[0] // ...................................*.............................. + sqrdmulh v16.8H, v11.8H, v0.H[1] // ....................................*............................. + mul v9.8H, v11.8H, v0.H[0] // .....................................*............................ + add v6.8H, v15.8H, v5.8H // ......................................*........................... + add v26.8H, v28.8H, v4.8H // .......................................*.......................... 
+ sub v15.8H, v28.8H, v4.8H // ........................................*......................... + mls v9.8H, v16.8H, v7.H[0] // .........................................*........................ + add v3.8H, v6.8H, v26.8H // ..........................................*....................... + mul v8.8H, v15.8H, v0.H[4] // ...........................................*...................... + sqrdmulh v15.8H, v15.8H, v0.H[5] // ............................................*..................... + str q9, [x0, #256] // .............................................*.................... + sub v2.8H, v6.8H, v26.8H // ..............................................*................... + str q3, [x0, #64] // ...............................................*.................. + mls v8.8H, v15.8H, v7.H[0] // ................................................*................. + sqrdmulh v15.8H, v2.8H, v0.H[1] // .................................................*................ + mul v11.8H, v2.8H, v0.H[0] // ..................................................*............... + add v16.8H, v13.8H, v18.8H // ...................................................*.............. + sub v12.8H, v24.8H, v8.8H // ....................................................*............. + add v8.8H, v24.8H, v8.8H // .....................................................*............ + mls v11.8H, v15.8H, v7.H[0] // ......................................................*........... + sqrdmulh v26.8H, v12.8H, v0.H[1] // .......................................................*.......... + mul v12.8H, v12.8H, v0.H[0] // ........................................................*......... + str q8, [x0, #192] // .........................................................*........ + add v15.8H, v19.8H, v23.8H // ..........................................................*....... + str q11, [x0, #320] // ...........................................................*...... 
+ mls v12.8H, v26.8H, v7.H[0] // ............................................................*..... + str q15, [x0], #(16) // .............................................................*.... + str q16, [x0, #112] // ...............................................................*.. + str q12, [x0, #432] // .................................................................* + + // ----------------------- cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|--------------- + // sub v12.8H, v11.8H, v15.8H // ..*............................................................... + // add v26.8H, v11.8H, v15.8H // *................................................................. + // sub v8.8H, v24.8H, v16.8H // ........*......................................................... + // sqrdmulh v11.8H, v12.8H, v0.H[7] // .........*........................................................ + // mul v12.8H, v12.8H, v0.H[6] // ..............*................................................... + // sub v16.8H, v26.8H, v19.8H // ...*.............................................................. + // add v26.8H, v26.8H, v19.8H // ..............................*................................... + // sqrdmulh v15.8H, v8.8H, v1.H[1] // ............*..................................................... + // mul v8.8H, v8.8H, v1.H[0] // .............*.................................................... + // mls v12.8H, v11.8H, v7.H[0] // ....................*............................................. + // sub v11.8H, v9.8H, v6.8H // .........................*........................................ + // sqrdmulh v24.8H, v16.8H, v0.H[3] // ......*........................................................... + // mul v16.8H, v16.8H, v0.H[2] // .......*.......................................................... + // sub v9.8H, v26.8H, v23.8H // .................................*................................ 
+ // add v26.8H, v26.8H, v23.8H // ..........................................................*....... + // mls v8.8H, v15.8H, v7.H[0] // .................*................................................ + // sqrdmulh v15.8H, v11.8H, v1.H[3] // .............................*.................................... + // mul v11.8H, v11.8H, v1.H[2] // ...............................*.................................. + // sub v6.8H, v3.8H, v4.8H // ................*................................................. + // sub v3.8H, v12.8H, v8.8H // ........................*......................................... + // add v12.8H, v12.8H, v8.8H // ......................................*........................... + // mls v11.8H, v15.8H, v7.H[0] // ...................................*.............................. + // sqrdmulh v8.8H, v6.8H, v1.H[5] // .....................*............................................ + // mls v16.8H, v24.8H, v7.H[0] // ...........*...................................................... + // mul v15.8H, v6.8H, v1.H[4] // ......................*........................................... + // sqrdmulh v24.8H, v3.8H, v0.H[3] // ............................*..................................... + // mul v6.8H, v3.8H, v0.H[2] // ...........................*...................................... + // sqrdmulh v3.8H, v9.8H, v0.H[1] // ....................................*............................. + // mul v9.8H, v9.8H, v0.H[0] // .....................................*............................ + // str q26, [x0], #(16) // .............................................................*.... + // mls v15.8H, v8.8H, v7.H[0] // ..........................*....................................... + // mls v6.8H, v24.8H, v7.H[0] // ................................*................................. + // sub v26.8H, v28.8H, v13.8H // .*................................................................ 
+ // mls v9.8H, v3.8H, v7.H[0] // .........................................*........................ + // sub v8.8H, v11.8H, v15.8H // ........................................*......................... + // sqrdmulh v24.8H, v26.8H, v0.H[5] // .....*............................................................ + // mul v26.8H, v26.8H, v0.H[4] // ....*............................................................. + // add v11.8H, v11.8H, v15.8H // .......................................*.......................... + // sqrdmulh v15.8H, v8.8H, v0.H[5] // ............................................*..................... + // mul v8.8H, v8.8H, v0.H[4] // ...........................................*...................... + // mls v26.8H, v24.8H, v7.H[0] // ..........*....................................................... + // sub v24.8H, v12.8H, v11.8H // ..............................................*................... + // add v12.8H, v12.8H, v11.8H // ..........................................*....................... + // mls v8.8H, v15.8H, v7.H[0] // ................................................*................. + // sqrdmulh v11.8H, v24.8H, v0.H[1] // .................................................*................ + // mul v15.8H, v24.8H, v0.H[0] // ..................................................*............... + // sub v24.8H, v16.8H, v26.8H // ...............*.................................................. + // add v26.8H, v16.8H, v26.8H // ...................................................*.............. + // sub v16.8H, v6.8H, v8.8H // ....................................................*............. + // mls v15.8H, v11.8H, v7.H[0] // ......................................................*........... + // sqrdmulh v11.8H, v24.8H, v0.H[1] // ..................*............................................... + // mul v24.8H, v24.8H, v0.H[0] // ...................*.............................................. 
+ // add v8.8H, v6.8H, v8.8H // .....................................................*............ + // sqrdmulh v6.8H, v16.8H, v0.H[1] // .......................................................*.......... + // mul v16.8H, v16.8H, v0.H[0] // ........................................................*......... + // mls v24.8H, v11.8H, v7.H[0] // .......................*.......................................... + // str q9, [x0, #240] // .............................................*.................... + // mls v16.8H, v6.8H, v7.H[0] // ............................................................*..... + // str q15, [x0, #304] // ...........................................................*...... + // str q24, [x0, #368] // ..................................*............................... + // str q16, [x0, #432] // .................................................................* + // str q12, [x0, #48] // ...............................................*.................. + // str q26, [x0, #112] // ...............................................................*.. + // str q8, [x0, #176] // .........................................................*........ 
+ + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_clean.S new file mode 100644 index 0000000000..877a5f689f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/ntt_clean.S @@ -0,0 +1,283 @@ +/// +/// Copyright (c) 2022 Arm Limited +/// Copyright (c) 2022 Hanno Becker +/// Copyright (c) 2023 Amin Abdulrahman, Matthias Kannwischer +/// Copyright (c) 2024 The mlkem-native project authors +// SPDX-License-Identifier: MIT +/// +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in all +/// copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +/// SOFTWARE. +/// + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +// Bounds: +// If C is chosen so that |src| < q * C, then |dst| < q * (0.0508 * C + 1/2) +// +// See mlkem/reduce.c and test/test_bounds.py for more details. 
+.macro mulmodq dst, src, const, idx0, idx1 + // Signed Barrett multiplication using + // round-to-nearest-even-integer approximation. + // Following https://eprint.iacr.org/2021/986.pdf, this + // is functionally the same as a signed Montgomery multiplication + // with a suitable constant of absolute value < q. + sqrdmulh t2.8h, \src\().8h, \const\().h[\idx1\()] // t2 = saturated, rounded high half of 2 * src * const.h[idx1] + mul \dst\().8h, \src\().8h, \const\().h[\idx0\()] // dst = low 16 bits of src * const.h[idx0] + mls \dst\().8h, t2.8h, consts.h[0] // dst -= t2 * consts.h[0], t2 is clobbered +.endm + +.macro mulmod dst, src, const, const_twisted + sqrdmulh t2.8h, \src\().8h, \const_twisted\().8h // same three steps as mulmodq, but with a separate constant in every lane + mul \dst\().8h, \src\().8h, \const\().8h // dst = low 16 bits of src * const, lane by lane + mls \dst\().8h, t2.8h, consts.h[0] // dst -= t2 * consts.h[0], t2 is clobbered +.endm + +.macro ct_butterfly a, b, root, idx0, idx1 + mulmodq tmp, \b, \root, \idx0, \idx1 + sub \b\().8h, \a\().8h, tmp.8h // b = a - tmp, where tmp is the mulmodq product of b and root lanes idx0/idx1 + add \a\().8h, \a\().8h, tmp.8h // a = a + tmp, tmp is clobbered +.endm + +.macro ct_butterfly_v a, b, root, root_twisted + mulmod tmp, \b, \root, \root_twisted + sub \b\().8h, \a\().8h, tmp.8h // b = a - tmp, where tmp is the mulmod product of b and the root vector pair + add \a\().8h, \a\().8h, tmp.8h // a = a + tmp, tmp is clobbered +.endm + +.macro load_roots_012 + ldr q_root0, [r01234_ptr], #32 // load 16 bytes of twiddles, then advance the pointer by 32 + ldr q_root1, [r01234_ptr, #-16] // load the next 16 bytes, addressed back from the advanced pointer +.endm + +.macro load_next_roots_34 + ldr q_root0, [r01234_ptr], #16 // load 16 bytes of twiddles and advance the pointer by 16 +.endm + +.macro load_next_roots_56 + ldr q_root0, [r56_ptr], #(6*16) // load vector 0 of a 6-vector (96-byte) group, then step past the whole group + ldr q_root0_tw, [r56_ptr, #(-6*16 + 1*16)] // vectors 1..5 of that group are addressed back from the advanced pointer + ldr q_root1, [r56_ptr, #(-6*16 + 2*16)] // memory order: root0, root0_tw, root1, root1_tw, root2, root2_tw + ldr q_root1_tw, [r56_ptr, #(-6*16 + 3*16)] + ldr q_root2, [r56_ptr, #(-6*16 + 4*16)] + ldr q_root2_tw, [r56_ptr, #(-6*16 + 5*16)] +.endm + +.macro transpose4 data + trn1 t0.4s, \data\()0.4s, \data\()1.4s // 4x4 transpose of 32-bit elements across the registers named data0..data3 (data is a name prefix) + trn2 t1.4s, \data\()0.4s, \data\()1.4s // step 1: interleave even/odd 32-bit lanes of rows 0,1 and of rows 2,3 into t0-t3 + trn1 t2.4s, \data\()2.4s, \data\()3.4s + trn2 t3.4s, \data\()2.4s, \data\()3.4s + + trn2 \data\()2.2d, t0.2d, t2.2d // step 2: recombine the 64-bit halves into the transposed rows, t0-t3 are clobbered + trn2 \data\()3.2d, t1.2d, t3.2d + trn1 \data\()0.2d, t0.2d, t2.2d + trn1 \data\()1.2d, t1.2d, t3.2d +.endm + +.macro save_vregs + sub sp, sp, #(16*4) // reserve 64 bytes of stack + stp d8, d9, [sp, #16*0] // spill d8-d15, the low halves of v8-v15 that AAPCS64 requires a callee to preserve + stp d10, d11, [sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] // reload d8-d15 from the slots written by save_vregs + ldp d10, d11, [sp, #16*1] +
ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + // Arguments + in .req x0 // Input/output buffer + r01234_ptr .req x1 // twiddles for layer 0,1,2,3,4 + r56_ptr .req x2 // twiddles for layer 5,6 + + inp .req x3 // working copy of in, set by the mov at function entry + count .req x4 // iteration counter, initialised to 4 at function entry + xtmp .req x5 + + data0 .req v8 // data0-data7 occupy v8-v15, the registers whose low halves save_vregs spills + data1 .req v9 + data2 .req v10 + data3 .req v11 + data4 .req v12 + data5 .req v13 + data6 .req v14 + data7 .req v15 + + q_data0 .req q8 // q-register (full 128-bit) names for data0-data7 + q_data1 .req q9 + q_data2 .req q10 + q_data3 .req q11 + q_data4 .req q12 + q_data5 .req q13 + q_data6 .req q14 + q_data7 .req q15 + + root0 .req v0 // root0-root2 and their _tw companions are filled by the load_roots_012 / load_next_roots_34 / load_next_roots_56 macros + root1 .req v1 + root2 .req v2 + root0_tw .req v4 + root1_tw .req v5 + root2_tw .req v6 + + q_root0 .req q0 // q-register (full 128-bit) names for the root registers + q_root1 .req q1 + q_root2 .req q2 + q_root0_tw .req q4 + q_root1_tw .req q5 + q_root2_tw .req q6 + + consts .req v7 // loaded from c_consts at function entry, lane 0 is the multiplier used by mls in mulmodq and mulmod + q_consts .req q7 + + tmp .req v24 // product scratch for ct_butterfly and ct_butterfly_v + t0 .req v25 // t0-t3: scratch for transpose4, t2 is also clobbered by mulmodq and mulmod + t1 .req v26 + t2 .req v27 + t3 .req v28 + + .text + .global MLKEM_ASM_NAMESPACE(ntt_asm_clean) + +/* Literal pool */ +.p2align 4 +c_consts: + .short 3329 // lane 0: 3329, the ML-KEM modulus q + .short 20159 // lane 1: presumably round(2^26 / 3329), a Barrett constant, not referenced in the code visible here - TODO confirm + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + .short 0 + +MLKEM_ASM_NAMESPACE(ntt_asm_clean): + push_stack + ldr q_consts, c_consts + + mov inp, in + mov count, #4 + + load_roots_012 + + .p2align 2 + + // Bounds reasoning: + // - There are 7 layers + // - When passing from layer N to layer N+1, each layer-N value + // is modified through the addition/subtraction of a Montgomery + // product of a twiddle of absolute value < q/2 and a layer-N value. + // - Recalling that for C such that |a| < C * q and |t| + // 0 25 + // |------------------------|---- + ldr q21, [x0, #0] // *............................. + ldr q26, [x0, #64] // ..*........................... + ldr q29, [x0, #128] // ....*......................... + ldr q20, [x0, #192] // ......*....................... + ldr q23, [x0, #256] // ........*..................... 
+ ldr q11, [x0, #448] // ..........*................... + mul v2.8H, v23.8H, v0.H[0] // ............*................. + ldr q17, [x0, #320] // .............*................ + mul v15.8H, v11.8H, v0.H[0] // ...............*.............. + ldr q13, [x0, #384] // ................*............. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q21, [x0, #0] // *.............................. + // ldr q26, [x0, #64] // ..*............................ + // ldr q29, [x0, #128] // ....*.......................... + // ldr q20, [x0, #192] // ......*........................ + // ldr q23, [x0, #256] // ........*...................... + // ldr q17, [x0, #320] // .............*................. + // mul v2.8H, v23.8H, v0.H[0] // ............*.................. + // ldr q11, [x0, #448] // ..........*.................... + // ldr q13, [x0, #384] // ................*.............. + // mul v15.8H, v11.8H, v0.H[0] // ...............*............... + + sub count, count, #1 +1: + // Instructions: 76 + // Expected cycles: 84 + // Expected IPC: 0.90 + // + // Cycle bound: 84.0 + // IPC bound: 0.90 + // + // Wall time: 2.36s + // User time: 2.36s + // + // -------------------------------- cycle (expected) ---------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|-------- + sqrdmulh v14.8H, v23.8H, v0.H[1] // *................................................................................... + sqrdmulh v23.8H, v17.8H, v0.H[1] // .*.................................................................................. + mul v17.8H, v17.8H, v0.H[0] // ..*................................................................................. + sqrdmulh v28.8H, v13.8H, v0.H[1] // ...*................................................................................ + mls v2.8H, v14.8H, v7.H[0] // ....*............................................................................... 
+ mul v14.8H, v13.8H, v0.H[0] // .....*.............................................................................. + mls v17.8H, v23.8H, v7.H[0] // ......*............................................................................. + sqrdmulh v23.8H, v11.8H, v0.H[1] // .......*............................................................................ + sub v11.8H, v21.8H, v2.8H // ........*........................................................................... + mls v14.8H, v28.8H, v7.H[0] // .........*.......................................................................... + sub v28.8H, v26.8H, v17.8H // ..........*......................................................................... + add v17.8H, v26.8H, v17.8H // ...........*........................................................................ + add v2.8H, v21.8H, v2.8H // ............*....................................................................... + sub v13.8H, v29.8H, v14.8H // .............*...................................................................... + add v14.8H, v29.8H, v14.8H // ..............*..................................................................... + mls v15.8H, v23.8H, v7.H[0] // ...............*.................................................................... + sqrdmulh v23.8H, v13.8H, v0.H[5] // ................*................................................................... + mul v13.8H, v13.8H, v0.H[4] // .................*.................................................................. + sqrdmulh v21.8H, v14.8H, v0.H[3] // ..................*................................................................. + sub v26.8H, v20.8H, v15.8H // ...................*................................................................ + add v15.8H, v20.8H, v15.8H // ....................*............................................................... 
+ mls v13.8H, v23.8H, v7.H[0] // .....................*.............................................................. + sqrdmulh v23.8H, v26.8H, v0.H[5] // ......................*............................................................. + mul v26.8H, v26.8H, v0.H[4] // .......................*............................................................ + mul v14.8H, v14.8H, v0.H[2] // ........................*........................................................... + sub v29.8H, v11.8H, v13.8H // .........................*.......................................................... + add v11.8H, v11.8H, v13.8H // ..........................*......................................................... + mls v26.8H, v23.8H, v7.H[0] // ...........................*........................................................ + sqrdmulh v23.8H, v15.8H, v0.H[3] // ............................*....................................................... + mul v13.8H, v15.8H, v0.H[2] // .............................*...................................................... + mls v14.8H, v21.8H, v7.H[0] // ..............................*..................................................... + sub v15.8H, v28.8H, v26.8H // ...............................*.................................................... + add v28.8H, v28.8H, v26.8H // ................................*................................................... + mls v13.8H, v23.8H, v7.H[0] // .................................*.................................................. + sub v23.8H, v2.8H, v14.8H // ..................................*................................................. + add v14.8H, v2.8H, v14.8H // ...................................*................................................ + sqrdmulh v2.8H, v28.8H, v1.H[3] // ....................................*............................................... 
+ sub v21.8H, v17.8H, v13.8H // .....................................*.............................................. + add v17.8H, v17.8H, v13.8H // ......................................*............................................. + mul v28.8H, v28.8H, v1.H[2] // .......................................*............................................ + sqrdmulh v13.8H, v21.8H, v1.H[1] // ........................................*........................................... + sqrdmulh v26.8H, v17.8H, v0.H[7] // .........................................*.......................................... + mul v17.8H, v17.8H, v0.H[6] // ..........................................*......................................... + mul v21.8H, v21.8H, v1.H[0] // ...........................................*........................................ + mls v28.8H, v2.8H, v7.H[0] // ............................................*....................................... + sqrdmulh v2.8H, v15.8H, v1.H[5] // .............................................*...................................... + mls v17.8H, v26.8H, v7.H[0] // ..............................................*..................................... + mls v21.8H, v13.8H, v7.H[0] // ...............................................*.................................... + sub v13.8H, v11.8H, v28.8H // ................................................*................................... + add v28.8H, v11.8H, v28.8H // .................................................*.................................. + sub v11.8H, v14.8H, v17.8H // ..................................................*................................. + mul v15.8H, v15.8H, v1.H[4] // ...................................................*................................ + add v14.8H, v14.8H, v17.8H // ....................................................*............................... 
+ sub v17.8H, v23.8H, v21.8H // .....................................................*.............................. + add v23.8H, v23.8H, v21.8H // ......................................................*............................. + mls v15.8H, v2.8H, v7.H[0] // .......................................................*............................ + str q14, [x0], #(16) // ........................................................*........................... + ldr q21, [x0, #0] // .........................................................e.......................... + sub v14.8H, v29.8H, v15.8H // ...........................................................*........................ + add v2.8H, v29.8H, v15.8H // ............................................................*....................... + str q11, [x0, #48] // .............................................................*...................... + ldr q26, [x0, #64] // ..............................................................e..................... + str q23, [x0, #112] // ................................................................*................... + ldr q29, [x0, #128] // .................................................................e.................. + str q17, [x0, #176] // ...................................................................*................ + ldr q20, [x0, #192] // ....................................................................e............... + str q28, [x0, #240] // ......................................................................*............. + ldr q23, [x0, #256] // .......................................................................e............ + str q13, [x0, #304] // .........................................................................*.......... + ldr q17, [x0, #320] // ..........................................................................e......... 
+ str q2, [x0, #368] // ............................................................................*....... + mul v2.8H, v23.8H, v0.H[0] // .............................................................................e...... + str q14, [x0, #432] // ..............................................................................*..... + ldr q11, [x0, #448] // ...............................................................................e.... + ldr q13, [x0, #384] // .................................................................................e.. + mul v15.8H, v11.8H, v0.H[0] // ...................................................................................e + + // ------------------------------------------- cycle (expected) --------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|----- + // ldr q8, [x0, #0] // e..........................'........................................................~..................... + // ldr q9, [x0, #(1*(512/8))] // .....e.....................'.............................................................~................ + // ldr q10, [x0, #(2*(512/8))] // ........e..................'................................................................~............. + // ldr q11, [x0, #(3*(512/8))] // ...........e...............'...................................................................~.......... + // ldr q12, [x0, #(4*(512/8))] // ..............e............'......................................................................~....... + // ldr q13, [x0, #(5*(512/8))] // .................e.........'.........................................................................~.... + // ldr q14, [x0, #(6*(512/8))] // ........................e..'.............................................................................. 
+ // ldr q15, [x0, #(7*(512/8))] // ......................e....'.............................................................................. + // sqrdmulh v27.8h, v12.8h, v0.h[1] // ...........................*.............................................................................. + // mul v24.8h, v12.8h, v0.h[0] // ....................e......'............................................................................~. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'...*.......................................................................... + // sub v12.8h, v8.8h, v24.8h // ...........................'.......*...................................................................... + // add v8.8h, v8.8h, v24.8h // ...........................'...........*.................................................................. + // sqrdmulh v27.8h, v13.8h, v0.h[1] // ...........................'*............................................................................. + // mul v24.8h, v13.8h, v0.h[0] // ...........................'.*............................................................................ + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.....*........................................................................ + // sub v13.8h, v9.8h, v24.8h // ...........................'.........*.................................................................... + // add v9.8h, v9.8h, v24.8h // ...........................'..........*................................................................... + // sqrdmulh v27.8h, v14.8h, v0.h[1] // ...........................'..*........................................................................... + // mul v24.8h, v14.8h, v0.h[0] // ...........................'....*......................................................................... 
+ // mls v24.8h, v27.8h, v7.h[0] // ...........................'........*..................................................................... + // sub v14.8h, v10.8h, v24.8h // ...........................'............*................................................................. + // add v10.8h, v10.8h, v24.8h // ...........................'.............*................................................................ + // sqrdmulh v27.8h, v15.8h, v0.h[1] // ...........................'......*....................................................................... + // mul v24.8h, v15.8h, v0.h[0] // ..........................e'.............................................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..............*............................................................... + // sub v15.8h, v11.8h, v24.8h // ...........................'..................*........................................................... + // add v11.8h, v11.8h, v24.8h // ...........................'...................*.......................................................... + // sqrdmulh v27.8h, v10.8h, v0.h[3] // ...........................'.................*............................................................ + // mul v24.8h, v10.8h, v0.h[2] // ...........................'.......................*...................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.............................*................................................ + // sub v10.8h, v8.8h, v24.8h // ...........................'.................................*............................................ + // add v8.8h, v8.8h, v24.8h // ...........................'..................................*........................................... 
+ // sqrdmulh v27.8h, v11.8h, v0.h[3] // ...........................'...........................*.................................................. + // mul v24.8h, v11.8h, v0.h[2] // ...........................'............................*................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'................................*............................................. + // sub v11.8h, v9.8h, v24.8h // ...........................'....................................*......................................... + // add v9.8h, v9.8h, v24.8h // ...........................'.....................................*........................................ + // sqrdmulh v27.8h, v14.8h, v0.h[5] // ...........................'...............*.............................................................. + // mul v24.8h, v14.8h, v0.h[4] // ...........................'................*............................................................. + // mls v24.8h, v27.8h, v7.h[0] // ...........................'....................*......................................................... + // sub v14.8h, v12.8h, v24.8h // ...........................'........................*..................................................... + // add v12.8h, v12.8h, v24.8h // ...........................'.........................*.................................................... + // sqrdmulh v27.8h, v15.8h, v0.h[5] // ...........................'.....................*........................................................ + // mul v24.8h, v15.8h, v0.h[4] // ...........................'......................*....................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..........................*................................................... 
+ // sub v15.8h, v13.8h, v24.8h // ...........................'..............................*............................................... + // add v13.8h, v13.8h, v24.8h // ...........................'...............................*.............................................. + // sqrdmulh v27.8h, v9.8h, v0.h[7] // ...........................'........................................*..................................... + // mul v24.8h, v9.8h, v0.h[6] // ...........................'.........................................*.................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'.............................................*................................ + // sub v9.8h, v8.8h, v24.8h // ...........................'.................................................*............................ + // add v8.8h, v8.8h, v24.8h // ...........................'...................................................*.......................... + // sqrdmulh v27.8h, v11.8h, v1.h[1] // ...........................'.......................................*...................................... + // mul v24.8h, v11.8h, v1.h[0] // ...........................'..........................................*................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'..............................................*............................... + // sub v11.8h, v10.8h, v24.8h // ...........................'....................................................*......................... + // add v10.8h, v10.8h, v24.8h // ...........................'.....................................................*........................ + // sqrdmulh v27.8h, v13.8h, v1.h[3] // ...........................'...................................*.......................................... 
+ // mul v24.8h, v13.8h, v1.h[2] // ...........................'......................................*....................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'...........................................*.................................. + // sub v13.8h, v12.8h, v24.8h // ...........................'...............................................*.............................. + // add v12.8h, v12.8h, v24.8h // ...........................'................................................*............................. + // sqrdmulh v27.8h, v15.8h, v1.h[5] // ...........................'............................................*................................. + // mul v24.8h, v15.8h, v1.h[4] // ...........................'..................................................*........................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................'......................................................*....................... + // sub v15.8h, v14.8h, v24.8h // ..~........................'..........................................................*................... + // add v14.8h, v14.8h, v24.8h // ...~.......................'...........................................................*.................. + // str q8, [x0], #(16) // ...........................'.......................................................*...................... + // str q9, [x0, #(-16 + 1*(512/8))] // ....~......................'............................................................*................. + // str q10, [x0, #(-16 + 2*(512/8))] // .......~...................'...............................................................*.............. + // str q11, [x0, #(-16 + 3*(512/8))] // ..........~................'..................................................................*........... 
+ // str q12, [x0, #(-16 + 4*(512/8))] // .............~.............'.....................................................................*........ + // str q13, [x0, #(-16 + 5*(512/8))] // ................~..........'........................................................................*..... + // str q14, [x0, #(-16 + 6*(512/8))] // ...................~.......'...........................................................................*.. + // str q15, [x0, #(-16 + 7*(512/8))] // .....................~.....'.............................................................................* + + sub count, count, 1 + cbnz count, 1b + // Instructions: 66 + // Expected cycles: 67 + // Expected IPC: 0.99 + // + // Cycle bound: 67.0 + // IPC bound: 0.99 + // + // Wall time: 7.51s + // User time: 7.51s + // + // ------------------------ cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|---------------- + sqrdmulh v27.8H, v11.8H, v0.H[1] // *.................................................................. + mul v8.8H, v13.8H, v0.H[0] // .*................................................................. + sqrdmulh v22.8H, v13.8H, v0.H[1] // ..*................................................................ + mul v11.8H, v17.8H, v0.H[0] // ...*............................................................... + mls v15.8H, v27.8H, v7.H[0] // ....*.............................................................. + sqrdmulh v28.8H, v17.8H, v0.H[1] // .....*............................................................. + mls v8.8H, v22.8H, v7.H[0] // ......*............................................................ + sqrdmulh v5.8H, v23.8H, v0.H[1] // .......*........................................................... + add v16.8H, v20.8H, v15.8H // ........*.......................................................... + mls v11.8H, v28.8H, v7.H[0] // .........*......................................................... 
+ sub v6.8H, v29.8H, v8.8H // ..........*........................................................ + sqrdmulh v17.8H, v16.8H, v0.H[3] // ...........*....................................................... + mul v23.8H, v16.8H, v0.H[2] // ............*...................................................... + mul v13.8H, v6.8H, v0.H[4] // .............*..................................................... + sqrdmulh v28.8H, v6.8H, v0.H[5] // ..............*.................................................... + mls v2.8H, v5.8H, v7.H[0] // ...............*................................................... + mls v23.8H, v17.8H, v7.H[0] // ................*.................................................. + add v27.8H, v26.8H, v11.8H // .................*................................................. + mls v13.8H, v28.8H, v7.H[0] // ..................*................................................ + sub v9.8H, v21.8H, v2.8H // ...................*............................................... + add v18.8H, v29.8H, v8.8H // ....................*.............................................. + sub v14.8H, v27.8H, v23.8H // .....................*............................................. + add v29.8H, v9.8H, v13.8H // ......................*............................................ + sub v30.8H, v9.8H, v13.8H // .......................*........................................... + mul v28.8H, v14.8H, v1.H[0] // ........................*.......................................... + sqrdmulh v9.8H, v18.8H, v0.H[3] // .........................*......................................... + mul v22.8H, v18.8H, v0.H[2] // ..........................*........................................ + sqrdmulh v17.8H, v14.8H, v1.H[1] // ...........................*....................................... + sub v14.8H, v20.8H, v15.8H // ............................*...................................... 
+ add v24.8H, v21.8H, v2.8H // .............................*..................................... + mls v22.8H, v9.8H, v7.H[0] // ..............................*.................................... + sqrdmulh v9.8H, v14.8H, v0.H[5] // ...............................*................................... + mul v13.8H, v14.8H, v0.H[4] // ................................*.................................. + mls v28.8H, v17.8H, v7.H[0] // .................................*................................. + sub v5.8H, v24.8H, v22.8H // ..................................*................................ + sub v2.8H, v26.8H, v11.8H // ...................................*............................... + mls v13.8H, v9.8H, v7.H[0] // ....................................*.............................. + sub v17.8H, v5.8H, v28.8H // .....................................*............................. + add v14.8H, v5.8H, v28.8H // ......................................*............................ + add v28.8H, v27.8H, v23.8H // .......................................*........................... + str q17, [x0, #192] // ........................................*.......................... + add v17.8H, v2.8H, v13.8H // .........................................*......................... + str q14, [x0, #128] // ..........................................*........................ + sub v13.8H, v2.8H, v13.8H // ...........................................*....................... + sqrdmulh v26.8H, v17.8H, v1.H[3] // ............................................*...................... + mul v15.8H, v17.8H, v1.H[2] // .............................................*..................... + add v5.8H, v24.8H, v22.8H // ..............................................*.................... + sqrdmulh v23.8H, v13.8H, v1.H[5] // ...............................................*................... + mul v13.8H, v13.8H, v1.H[4] // ................................................*.................. 
+ mls v15.8H, v26.8H, v7.H[0] // .................................................*................. + sqrdmulh v14.8H, v28.8H, v0.H[7] // ..................................................*................ + mul v17.8H, v28.8H, v0.H[6] // ...................................................*............... + mls v13.8H, v23.8H, v7.H[0] // ....................................................*.............. + add v6.8H, v29.8H, v15.8H // .....................................................*............. + sub v28.8H, v29.8H, v15.8H // ......................................................*............ + mls v17.8H, v14.8H, v7.H[0] // .......................................................*........... + str q6, [x0, #256] // ........................................................*.......... + add v14.8H, v30.8H, v13.8H // .........................................................*......... + str q28, [x0, #320] // ..........................................................*........ + sub v23.8H, v30.8H, v13.8H // ...........................................................*....... + str q14, [x0, #384] // ............................................................*...... + add v3.8H, v5.8H, v17.8H // .............................................................*..... + str q23, [x0, #448] // ..............................................................*.... + sub v28.8H, v5.8H, v17.8H // ...............................................................*... + str q3, [x0], #(16) // ................................................................*.. + str q28, [x0, #48] // ..................................................................* + + // ------------------------ cycle (expected) ------------------------> + // 0 25 50 + // |------------------------|------------------------|---------------- + // sqrdmulh v14.8H, v23.8H, v0.H[1] // .......*........................................................... 
+ // sqrdmulh v23.8H, v17.8H, v0.H[1] // .....*............................................................. + // mul v17.8H, v17.8H, v0.H[0] // ...*............................................................... + // sqrdmulh v28.8H, v13.8H, v0.H[1] // ..*................................................................ + // mls v2.8H, v14.8H, v7.H[0] // ...............*................................................... + // mul v14.8H, v13.8H, v0.H[0] // .*................................................................. + // mls v17.8H, v23.8H, v7.H[0] // .........*......................................................... + // sqrdmulh v23.8H, v11.8H, v0.H[1] // *.................................................................. + // sub v11.8H, v21.8H, v2.8H // ...................*............................................... + // mls v14.8H, v28.8H, v7.H[0] // ......*............................................................ + // sub v28.8H, v26.8H, v17.8H // ...................................*............................... + // add v17.8H, v26.8H, v17.8H // .................*................................................. + // add v2.8H, v21.8H, v2.8H // .............................*..................................... + // sub v13.8H, v29.8H, v14.8H // ..........*........................................................ + // add v14.8H, v29.8H, v14.8H // ....................*.............................................. + // mls v15.8H, v23.8H, v7.H[0] // ....*.............................................................. + // sqrdmulh v23.8H, v13.8H, v0.H[5] // ..............*.................................................... + // mul v13.8H, v13.8H, v0.H[4] // .............*..................................................... + // sqrdmulh v21.8H, v14.8H, v0.H[3] // .........................*......................................... + // sub v26.8H, v20.8H, v15.8H // ............................*...................................... 
+ // add v15.8H, v20.8H, v15.8H // ........*.......................................................... + // mls v13.8H, v23.8H, v7.H[0] // ..................*................................................ + // sqrdmulh v23.8H, v26.8H, v0.H[5] // ...............................*................................... + // mul v26.8H, v26.8H, v0.H[4] // ................................*.................................. + // mul v14.8H, v14.8H, v0.H[2] // ..........................*........................................ + // sub v29.8H, v11.8H, v13.8H // .......................*........................................... + // add v11.8H, v11.8H, v13.8H // ......................*............................................ + // mls v26.8H, v23.8H, v7.H[0] // ....................................*.............................. + // sqrdmulh v23.8H, v15.8H, v0.H[3] // ...........*....................................................... + // mul v13.8H, v15.8H, v0.H[2] // ............*...................................................... + // mls v14.8H, v21.8H, v7.H[0] // ..............................*.................................... + // sub v15.8H, v28.8H, v26.8H // ...........................................*....................... + // add v28.8H, v28.8H, v26.8H // .........................................*......................... + // mls v13.8H, v23.8H, v7.H[0] // ................*.................................................. + // sub v23.8H, v2.8H, v14.8H // ..................................*................................ + // add v14.8H, v2.8H, v14.8H // ..............................................*.................... + // sqrdmulh v2.8H, v28.8H, v1.H[3] // ............................................*...................... + // sub v21.8H, v17.8H, v13.8H // .....................*............................................. + // add v17.8H, v17.8H, v13.8H // .......................................*........................... 
+ // mul v28.8H, v28.8H, v1.H[2] // .............................................*..................... + // sqrdmulh v13.8H, v21.8H, v1.H[1] // ...........................*....................................... + // sqrdmulh v26.8H, v17.8H, v0.H[7] // ..................................................*................ + // mul v17.8H, v17.8H, v0.H[6] // ...................................................*............... + // mul v21.8H, v21.8H, v1.H[0] // ........................*.......................................... + // mls v28.8H, v2.8H, v7.H[0] // .................................................*................. + // sqrdmulh v2.8H, v15.8H, v1.H[5] // ...............................................*................... + // mls v17.8H, v26.8H, v7.H[0] // .......................................................*........... + // mls v21.8H, v13.8H, v7.H[0] // .................................*................................. + // sub v13.8H, v11.8H, v28.8H // ......................................................*............ + // add v28.8H, v11.8H, v28.8H // .....................................................*............. + // sub v11.8H, v14.8H, v17.8H // ...............................................................*... + // mul v15.8H, v15.8H, v1.H[4] // ................................................*.................. + // add v14.8H, v14.8H, v17.8H // .............................................................*..... + // sub v17.8H, v23.8H, v21.8H // .....................................*............................. + // add v23.8H, v23.8H, v21.8H // ......................................*............................ + // mls v15.8H, v2.8H, v7.H[0] // ....................................................*.............. + // str q14, [x0], #(16) // ................................................................*.. + // sub v14.8H, v29.8H, v15.8H // ...........................................................*....... 
+ // add v2.8H, v29.8H, v15.8H // .........................................................*......... + // str q11, [x0, #48] // ..................................................................* + // str q23, [x0, #112] // ..........................................*........................ + // str q17, [x0, #176] // ........................................*.......................... + // str q28, [x0, #240] // ........................................................*.......... + // str q13, [x0, #304] // ..........................................................*........ + // str q2, [x0, #368] // ............................................................*...... + // str q14, [x0, #432] // ..............................................................*.... + + + mov in, inp + mov count, #8 + + .p2align 2 + // Instructions: 24 + // Expected cycles: 31 + // Expected IPC: 0.77 + // + // Cycle bound: 31.0 + // IPC bound: 0.77 + // + // Wall time: 0.08s + // User time: 0.08s + // + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + ldr q2, [x1], #16 // *.............................. + ldr q14, [x0, #48] // ..*............................ + ldr q1, [x0, #32] // ....*.......................... + mul v17.8H, v14.8H, v2.H[0] // ......*........................ + sqrdmulh v14.8H, v14.8H, v2.H[1] // .......*....................... + mul v8.8H, v1.8H, v2.H[0] // ........*...................... + ldr q23, [x0, #16] // .........*..................... + mls v17.8H, v14.8H, v7.H[0] // ...........*................... + sqrdmulh v1.8H, v1.8H, v2.H[1] // ............*.................. + ldr q30, [x2], #(6*16) // .............*................. + sub v14.8H, v23.8H, v17.8H // ...............*............... + add v10.8H, v23.8H, v17.8H // ................*.............. + mls v8.8H, v1.8H, v7.H[0] // .................*............. + sqrdmulh v1.8H, v14.8H, v2.H[5] // ..................*............ 
+ mul v14.8H, v14.8H, v2.H[4] // ...................*........... + ldr q27, [x0, #0] // ....................*.......... + mul v23.8H, v10.8H, v2.H[2] // ......................*........ + mls v14.8H, v1.8H, v7.H[0] // .......................*....... + sub v1.8H, v27.8H, v8.8H // ........................*...... + ldr q28, [x2, #-64] // .........................*..... + add v12.8H, v1.8H, v14.8H // ...........................*... + sqrdmulh v21.8H, v10.8H, v2.H[3] // ............................*.. + sub v5.8H, v1.8H, v14.8H // .............................*. + ldr q13, [x2, #-16] // ..............................* + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q19, [x0, #48] // ..*............................ + // ldr q1, [x1], #16 // *.............................. + // mul v4.8H, v19.8H, v1.H[0] // ......*........................ + // sqrdmulh v19.8H, v19.8H, v1.H[1] // .......*....................... + // ldr q25, [x0, #16] // .........*..................... + // mls v4.8H, v19.8H, v7.H[0] // ...........*................... + // sub v24.8H, v25.8H, v4.8H // ...............*............... + // add v4.8H, v25.8H, v4.8H // ................*.............. + // sqrdmulh v23.8H, v24.8H, v1.H[5] // ..................*............ + // mul v20.8H, v24.8H, v1.H[4] // ...................*........... + // sqrdmulh v21.8H, v4.8H, v1.H[3] // ............................*.. + // mls v20.8H, v23.8H, v7.H[0] // .......................*....... + // mul v23.8H, v4.8H, v1.H[2] // ......................*........ + // ldr q31, [x0, #32] // ....*.......................... + // mul v8.8H, v31.8H, v1.H[0] // ........*...................... + // sqrdmulh v1.8H, v31.8H, v1.H[1] // ............*.................. + // mls v8.8H, v1.8H, v7.H[0] // .................*............. + // ldr q27, [x0, #0] // ....................*.......... + // sub v10.8H, v27.8H, v8.8H // ........................*...... 
+ // add v12.8H, v10.8H, v20.8H // ...........................*... + // ldr q30, [x2], #(6*16) // .............*................. + // ldr q28, [x2, #-64] // .........................*..... + // sub v5.8H, v10.8H, v20.8H // .............................*. + // ldr q13, [x2, #-16] // ..............................* + + sub count, count, #1 +1: + // Instructions: 71 + // Expected cycles: 82 + // Expected IPC: 0.87 + // + // Cycle bound: 82.0 + // IPC bound: 0.87 + // + // Wall time: 11.93s + // User time: 11.93s + // + // ------------------------------- cycle (expected) --------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------ + ldr q19, [x0, #112] // e................................................................................. + ldr q1, [x1], #16 // ..e............................................................................... + mls v23.8H, v21.8H, v7.H[0] // ....*............................................................................. + add v6.8H, v27.8H, v8.8H // .....*............................................................................ + mul v4.8H, v19.8H, v1.H[0] // ......e........................................................................... + sqrdmulh v19.8H, v19.8H, v1.H[1] // .......e.......................................................................... + ldr q25, [x0, #80] // ........e......................................................................... + trn1 v11.4S, v12.4S, v5.4S // ..........*....................................................................... + mls v4.8H, v19.8H, v7.H[0] // ...........e...................................................................... + sub v0.8H, v6.8H, v23.8H // ............*..................................................................... + ldr q16, [x2, #-80] // .............*.................................................................... 
+ sub v24.8H, v25.8H, v4.8H // ...............e.................................................................. + add v26.8H, v6.8H, v23.8H // ................*................................................................. + add v4.8H, v25.8H, v4.8H // .................e................................................................ + sqrdmulh v23.8H, v24.8H, v1.H[5] // ..................e............................................................... + mul v20.8H, v24.8H, v1.H[4] // ...................e.............................................................. + sqrdmulh v21.8H, v4.8H, v1.H[3] // ....................e............................................................. + trn1 v27.4S, v26.4S, v0.4S // .....................*............................................................ + trn2 v25.4S, v12.4S, v5.4S // ......................*........................................................... + mls v20.8H, v23.8H, v7.H[0] // .......................e.......................................................... + mul v23.8H, v4.8H, v1.H[2] // ........................e......................................................... + ldr q31, [x0, #96] // .........................e........................................................ + trn2 v12.4S, v26.4S, v0.4S // ...........................*...................................................... + trn2 v19.2D, v27.2D, v11.2D // ............................*..................................................... + mul v8.8H, v31.8H, v1.H[0] // .............................e.................................................... + sqrdmulh v1.8H, v31.8H, v1.H[1] // ..............................e................................................... + trn2 v10.2D, v12.2D, v25.2D // ...............................*.................................................. + sqrdmulh v0.8H, v19.8H, v16.8H // ................................*................................................. 
+ sqrdmulh v18.8H, v10.8H, v16.8H // .................................*................................................ + trn1 v16.2D, v27.2D, v11.2D // ..................................*............................................... + trn1 v2.2D, v12.2D, v25.2D // ...................................*.............................................. + mul v12.8H, v10.8H, v30.8H // ....................................*............................................. + mul v10.8H, v19.8H, v30.8H // .....................................*............................................ + mls v8.8H, v1.8H, v7.H[0] // ......................................e........................................... + ldr q14, [x2, #-48] // .......................................*.......................................... + mls v10.8H, v0.8H, v7.H[0] // .........................................*........................................ + mls v12.8H, v18.8H, v7.H[0] // ..........................................*....................................... + ldr q27, [x0, #64] // ...........................................e...................................... + add v9.8H, v16.8H, v10.8H // .............................................*.................................... + sub v16.8H, v16.8H, v10.8H // ..............................................*................................... + sub v25.8H, v2.8H, v12.8H // ...............................................*.................................. + add v30.8H, v2.8H, v12.8H // ................................................*................................. + sub v10.8H, v27.8H, v8.8H // .................................................e................................ + sqrdmulh v22.8H, v25.8H, v13.8H // ..................................................*............................... + sqrdmulh v13.8H, v30.8H, v14.8H // ...................................................*.............................. 
+ ldr q14, [x2, #-32] // ....................................................*............................. + add v12.8H, v10.8H, v20.8H // ......................................................e........................... + mul v5.8H, v30.8H, v28.8H // .......................................................*.......................... + mul v26.8H, v25.8H, v14.8H // ........................................................*......................... + ldr q30, [x2], #(6*16) // .........................................................e........................ + mls v5.8H, v13.8H, v7.H[0] // ...........................................................*...................... + mls v26.8H, v22.8H, v7.H[0] // ............................................................*..................... + ldr q28, [x2, #-64] // .............................................................e.................... + add v13.8H, v9.8H, v5.8H // ...............................................................*.................. + sub v9.8H, v9.8H, v5.8H // ................................................................*................. + sub v5.8H, v16.8H, v26.8H // .................................................................*................ + add v25.8H, v16.8H, v26.8H // ..................................................................*............... + trn1 v15.4S, v13.4S, v9.4S // ...................................................................*.............. + trn2 v3.4S, v13.4S, v9.4S // ....................................................................*............. + trn1 v13.4S, v25.4S, v5.4S // .....................................................................*............ + trn2 v31.4S, v25.4S, v5.4S // ......................................................................*........... + sub v5.8H, v10.8H, v20.8H // .......................................................................e.......... 
+ trn1 v2.2D, v15.2D, v13.2D // ........................................................................*......... + trn2 v9.2D, v15.2D, v13.2D // .........................................................................*........ + str q2, [x0], #(16*4) // ..........................................................................*....... + trn1 v29.2D, v3.2D, v31.2D // ...........................................................................*...... + str q9, [x0, #-32] // ............................................................................*..... + trn2 v9.2D, v3.2D, v31.2D // .............................................................................*.... + str q29, [x0, #-48] // ..............................................................................*... + ldr q13, [x2, #-16] // ...............................................................................e.. + str q9, [x0, #-16] // .................................................................................* + + // ------------------------------------------------------------------------ cycle (expected) -------------------------------------------------------------------------> + // 0 25 50 75 100 125 150 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------- + // ldr q8, [x0, #(16*0)] // ...........................................e......................................'..........................................~...................................... + // ldr q9, [x0, #(16*1)] // ........e.........................................................................'.......~......................................................................... + // ldr q10, [x0, #(16*2)] // .........................e........................................................'........................~........................................................ 
+ // ldr q11, [x0, #(16*3)] // e.................................................................................~................................................................................. + // ldr q0, [x1], #16 // ..e...............................................................................'.~............................................................................... + // sqrdmulh v27.8h, v10.8h, v0.h[1] // ..............................e...................................................'.............................~................................................... + // mul v24.8h, v10.8h, v0.h[0] // .............................e....................................................'............................~.................................................... + // mls v24.8h, v27.8h, v7.h[0] // ......................................e...........................................'.....................................~........................................... + // sub v10.8h, v8.8h, v24.8h // .................................................e................................'................................................~................................ + // add v8.8h, v8.8h, v24.8h // .....~............................................................................'....*............................................................................ + // sqrdmulh v27.8h, v11.8h, v0.h[1] // .......e..........................................................................'......~.......................................................................... + // mul v24.8h, v11.8h, v0.h[0] // ......e...........................................................................'.....~........................................................................... + // mls v24.8h, v27.8h, v7.h[0] // ...........e......................................................................'..........~...................................................................... 
+ // sub v11.8h, v9.8h, v24.8h // ...............e..................................................................'..............~.................................................................. + // add v9.8h, v9.8h, v24.8h // .................e................................................................'................~................................................................ + // sqrdmulh v27.8h, v9.8h, v0.h[3] // ....................e.............................................................'...................~............................................................. + // mul v24.8h, v9.8h, v0.h[2] // ........................e.........................................................'.......................~......................................................... + // mls v24.8h, v27.8h, v7.h[0] // ....~.............................................................................'...*............................................................................. + // sub v9.8h, v8.8h, v24.8h // ............~.....................................................................'...........*..................................................................... + // add v8.8h, v8.8h, v24.8h // ................~.................................................................'...............*................................................................. + // sqrdmulh v27.8h, v11.8h, v0.h[5] // ..................e...............................................................'.................~............................................................... + // mul v24.8h, v11.8h, v0.h[4] // ...................e..............................................................'..................~.............................................................. 
+ // mls v24.8h, v27.8h, v7.h[0] // .......................e..........................................................'......................~.......................................................... + // sub v11.8h, v10.8h, v24.8h // .......................................................................e..........'......................................................................~.......... + // add v10.8h, v10.8h, v24.8h // ......................................................e...........................'.....................................................~........................... + // trn1 v25.4s, v8.4s, v9.4s // .....................~............................................................'....................*............................................................ + // trn2 v26.4s, v8.4s, v9.4s // ...........................~......................................................'..........................*...................................................... + // trn1 v27.4s, v10.4s, v11.4s // ..........~.......................................................................'.........*....................................................................... + // trn2 v28.4s, v10.4s, v11.4s // ......................~...........................................................'.....................*........................................................... + // trn2 v10.2d, v25.2d, v27.2d // ............................~.....................................................'...........................*..................................................... + // trn2 v11.2d, v26.2d, v28.2d // ...............................~..................................................'..............................*.................................................. 
+ // trn1 v8.2d, v25.2d, v27.2d // ..................................~...............................................'.................................*............................................... + // trn1 v9.2d, v26.2d, v28.2d // ...................................~..............................................'..................................*.............................................. + // ldr q0, [x2], #(6*16) // .........................................................e........................'........................................................~........................ + // ldr q4, [x2, #(-6*16 + 1*16)] // .............~....................................................................'............*.................................................................... + // ldr q1, [x2, #(-6*16 + 2*16)] // .............................................................e....................'............................................................~.................... + // ldr q5, [x2, #(-6*16 + 3*16)] // .......................................~..........................................'......................................*.......................................... + // ldr q2, [x2, #(-6*16 + 4*16)] // ....................................................~.............................'...................................................*............................. + // ldr q6, [x2, #(-6*16 + 5*16)] // ...............................................................................e..'..............................................................................~.. + // sqrdmulh v27.8h, v10.8h, v4.8h // ................................~.................................................'...............................*................................................. 
+ // mul v24.8h, v10.8h, v0.8h // .....................................~............................................'....................................*............................................ + // mls v24.8h, v27.8h, v7.h[0] // .........................................~........................................'........................................*........................................ + // sub v10.8h, v8.8h, v24.8h // ..............................................~...................................'.............................................*................................... + // add v8.8h, v8.8h, v24.8h // .............................................~....................................'............................................*.................................... + // sqrdmulh v27.8h, v11.8h, v4.8h // .................................~................................................'................................*................................................ + // mul v24.8h, v11.8h, v0.8h // ....................................~.............................................'...................................*............................................. + // mls v24.8h, v27.8h, v7.h[0] // ..........................................~.......................................'.........................................*....................................... + // sub v11.8h, v9.8h, v24.8h // ...............................................~..................................'..............................................*.................................. + // add v9.8h, v9.8h, v24.8h // ................................................~.................................'...............................................*................................. 
+ // sqrdmulh v27.8h, v9.8h, v5.8h // ...................................................~..............................'..................................................*.............................. + // mul v24.8h, v9.8h, v1.8h // .......................................................~..........................'......................................................*.......................... + // mls v24.8h, v27.8h, v7.h[0] // ...........................................................~......................'..........................................................*...................... + // sub v9.8h, v8.8h, v24.8h // ................................................................~.................'...............................................................*................. + // add v8.8h, v8.8h, v24.8h // ...............................................................~..................'..............................................................*.................. + // sqrdmulh v27.8h, v11.8h, v6.8h // ..................................................~...............................'.................................................*............................... + // mul v24.8h, v11.8h, v2.8h // ........................................................~.........................'.......................................................*......................... + // mls v24.8h, v27.8h, v7.h[0] // ............................................................~.....................'...........................................................*..................... + // sub v11.8h, v10.8h, v24.8h // .................................................................~................'................................................................*................ 
+ // add v10.8h, v10.8h, v24.8h // ..................................................................~...............'.................................................................*............... + // trn1 v25.4s, v8.4s, v9.4s // ...................................................................~..............'..................................................................*.............. + // trn2 v26.4s, v8.4s, v9.4s // ....................................................................~.............'...................................................................*............. + // trn1 v27.4s, v10.4s, v11.4s // .....................................................................~............'....................................................................*............ + // trn2 v28.4s, v10.4s, v11.4s // ......................................................................~...........'.....................................................................*........... + // trn2 v10.2d, v25.2d, v27.2d // .........................................................................~........'........................................................................*........ + // trn2 v11.2d, v26.2d, v28.2d // .............................................................................~....'............................................................................*.... + // trn1 v8.2d, v25.2d, v27.2d // ........................................................................~.........'.......................................................................*......... + // trn1 v9.2d, v26.2d, v28.2d // ...........................................................................~......'..........................................................................*...... + // str q8, [x0], #(16*4) // ..........................................................................~.......'.........................................................................*....... 
+ // str q9, [x0, #(-16*3)] // ..............................................................................~...'.............................................................................*... + // str q10, [x0, #(-16*2)] // ............................................................................~.....'...........................................................................*..... + // str q11, [x0, #(-16*1)] // .................................................................................~'................................................................................* + + sub count, count, 1 + cbnz count, 1b + // Instructions: 47 + // Expected cycles: 52 + // Expected IPC: 0.90 + // + // Cycle bound: 52.0 + // IPC bound: 0.90 + // + // Wall time: 5.32s + // User time: 5.32s + // + // ---------------- cycle (expected) -----------------> + // 0 25 50 + // |------------------------|------------------------|- + mls v23.8H, v21.8H, v7.H[0] // *................................................... + add v14.8H, v27.8H, v8.8H // .*.................................................. + ldr q1, [x2, #-32] // ..*................................................. + add v17.8H, v14.8H, v23.8H // ....*............................................... + sub v23.8H, v14.8H, v23.8H // .....*.............................................. + trn2 v11.4S, v12.4S, v5.4S // ......*............................................. + trn1 v27.4S, v12.4S, v5.4S // .......*............................................ + trn2 v2.4S, v17.4S, v23.4S // ........*........................................... + ldr q26, [x2, #-80] // .........*.......................................... + trn2 v14.2D, v2.2D, v11.2D // ...........*........................................ + trn1 v15.4S, v17.4S, v23.4S // ............*....................................... + mul v5.8H, v14.8H, v30.8H // .............*...................................... 
+ sqrdmulh v23.8H, v14.8H, v26.8H // ..............*..................................... + trn2 v17.2D, v15.2D, v27.2D // ...............*.................................... + trn1 v14.2D, v2.2D, v11.2D // ................*................................... + mul v21.8H, v17.8H, v30.8H // .................*.................................. + mls v5.8H, v23.8H, v7.H[0] // ..................*................................. + sqrdmulh v17.8H, v17.8H, v26.8H // ...................*................................ + ldr q2, [x2, #-48] // ....................*............................... + sub v23.8H, v14.8H, v5.8H // ......................*............................. + add v14.8H, v14.8H, v5.8H // .......................*............................ + mls v21.8H, v17.8H, v7.H[0] // ........................*........................... + mul v1.8H, v23.8H, v1.8H // .........................*.......................... + sqrdmulh v17.8H, v23.8H, v13.8H // ..........................*......................... + mul v23.8H, v14.8H, v28.8H // ...........................*........................ + sqrdmulh v14.8H, v14.8H, v2.8H // ............................*....................... + trn1 v28.2D, v15.2D, v27.2D // .............................*...................... + mls v1.8H, v17.8H, v7.H[0] // ..............................*..................... + sub v11.8H, v28.8H, v21.8H // ...............................*.................... + mls v23.8H, v14.8H, v7.H[0] // ................................*................... + add v17.8H, v28.8H, v21.8H // .................................*.................. + sub v14.8H, v11.8H, v1.8H // ..................................*................. + add v1.8H, v11.8H, v1.8H // ...................................*................ + sub v28.8H, v17.8H, v23.8H // ....................................*............... + add v2.8H, v17.8H, v23.8H // .....................................*.............. 
+ trn1 v23.4S, v1.4S, v14.4S // ......................................*............. + trn2 v14.4S, v1.4S, v14.4S // .......................................*............ + trn2 v17.4S, v2.4S, v28.4S // ........................................*........... + trn1 v28.4S, v2.4S, v28.4S // .........................................*.......... + trn2 v1.2D, v17.2D, v14.2D // ...........................................*........ + trn1 v14.2D, v17.2D, v14.2D // ............................................*....... + str q1, [x0, #48] // .............................................*...... + trn2 v1.2D, v28.2D, v23.2D // ..............................................*..... + str q14, [x0, #16] // ...............................................*.... + trn1 v14.2D, v28.2D, v23.2D // ................................................*... + str q1, [x0, #32] // .................................................*.. + str q14, [x0], #(16*4) // ...................................................* + + // ---------------- cycle (expected) -----------------> + // 0 25 50 + // |------------------------|------------------------|- + // mls v23.8H, v21.8H, v7.H[0] // *................................................... + // add v6.8H, v27.8H, v8.8H // .*.................................................. + // trn1 v11.4S, v12.4S, v5.4S // .......*............................................ + // sub v0.8H, v6.8H, v23.8H // .....*.............................................. + // ldr q16, [x2, #-80] // .........*.......................................... + // add v26.8H, v6.8H, v23.8H // ....*............................................... + // trn1 v27.4S, v26.4S, v0.4S // ............*....................................... + // trn2 v25.4S, v12.4S, v5.4S // ......*............................................. + // trn2 v12.4S, v26.4S, v0.4S // ........*........................................... + // trn2 v19.2D, v27.2D, v11.2D // ...............*.................................... 
+ // trn2 v10.2D, v12.2D, v25.2D // ...........*........................................ + // sqrdmulh v0.8H, v19.8H, v16.8H // ...................*................................ + // sqrdmulh v18.8H, v10.8H, v16.8H // ..............*..................................... + // trn1 v16.2D, v27.2D, v11.2D // .............................*...................... + // trn1 v2.2D, v12.2D, v25.2D // ................*................................... + // mul v12.8H, v10.8H, v30.8H // .............*...................................... + // mul v10.8H, v19.8H, v30.8H // .................*.................................. + // ldr q14, [x2, #-48] // ....................*............................... + // mls v10.8H, v0.8H, v7.H[0] // ........................*........................... + // mls v12.8H, v18.8H, v7.H[0] // ..................*................................. + // add v9.8H, v16.8H, v10.8H // .................................*.................. + // sub v16.8H, v16.8H, v10.8H // ...............................*.................... + // sub v25.8H, v2.8H, v12.8H // ......................*............................. + // add v30.8H, v2.8H, v12.8H // .......................*............................ + // sqrdmulh v22.8H, v25.8H, v13.8H // ..........................*......................... + // sqrdmulh v13.8H, v30.8H, v14.8H // ............................*....................... + // ldr q14, [x2, #-32] // ..*................................................. + // mul v5.8H, v30.8H, v28.8H // ...........................*........................ + // mul v26.8H, v25.8H, v14.8H // .........................*.......................... + // mls v5.8H, v13.8H, v7.H[0] // ................................*................... + // mls v26.8H, v22.8H, v7.H[0] // ..............................*..................... + // add v13.8H, v9.8H, v5.8H // .....................................*.............. 
+ // sub v9.8H, v9.8H, v5.8H // ....................................*............... + // sub v5.8H, v16.8H, v26.8H // ..................................*................. + // add v25.8H, v16.8H, v26.8H // ...................................*................ + // trn1 v15.4S, v13.4S, v9.4S // .........................................*.......... + // trn2 v3.4S, v13.4S, v9.4S // ........................................*........... + // trn1 v13.4S, v25.4S, v5.4S // ......................................*............. + // trn2 v31.4S, v25.4S, v5.4S // .......................................*............ + // trn1 v2.2D, v15.2D, v13.2D // ................................................*... + // trn2 v9.2D, v15.2D, v13.2D // ..............................................*..... + // str q2, [x0], #(16*4) // ...................................................* + // trn1 v29.2D, v3.2D, v31.2D // ............................................*....... + // str q9, [x0, #-32] // .................................................*.. + // trn2 v9.2D, v3.2D, v31.2D // ...........................................*........ + // str q29, [x0, #-48] // ...............................................*.... + // str q9, [x0, #-16] // .............................................*...... + + + pop_stack + ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/opt_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/opt_impl.h new file mode 100644 index 0000000000..b226740261 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/opt_impl.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include "arith_native_aarch64.h" + +#include "poly.h" +#include "polyvec.h" + +/* Set of primitives that this backend replaces */ +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_REJ_UNIFORM + +#define NTT_BOUND_NATIVE (6 * MLKEM_Q) +static INLINE void ntt_native(poly *data) /* forwards data->coeffs plus the layer01234 and layer56 zeta tables to ntt_asm_opt */ +{ + ntt_asm_opt(data->coeffs, aarch64_ntt_zetas_layer01234, + aarch64_ntt_zetas_layer56); +} + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +static INLINE void intt_native(poly *data) /* forwards data->coeffs plus the inverse-NTT layer01234 and layer56 zeta tables to intt_asm_opt */ +{ + intt_asm_opt(data->coeffs, aarch64_invntt_zetas_layer01234, + aarch64_invntt_zetas_layer56); +} + +static INLINE void poly_reduce_native(poly *data) +{ + poly_reduce_asm_opt(data->coeffs); +} +static INLINE void poly_tomont_native(poly *data) +{ + poly_tomont_asm_opt(data->coeffs); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + poly_mulcache_compute_asm_opt(x->coeffs, y->coeffs, + aarch64_zetas_mulcache_native, + aarch64_zetas_mulcache_twisted_native); +} +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) /* passes only the coefficient pointers of element 0 of a, b and b_cache; presumably the assembly walks the remaining vector elements from there -- TODO confirm */ +{ + polyvec_basemul_acc_montgomery_cached_asm_opt( + r->coeffs, a->vec[0].coeffs, b->vec[0].coeffs, b_cache->vec[0].coeffs); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + poly_tobytes_asm_opt(r, a->coeffs); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) /* returns -1 unless len == MLKEM_N and buflen is a multiple of 24; otherwise the result of rej_uniform_asm_clean cast to int */ +{ + if (len != MLKEM_N || buflen % 24 != 0) + { + return -1; + } + return (int)rej_uniform_asm_clean(r, buf, buflen, rej_uniform_table); /* NOTE(review): a *_clean routine is called from this opt profile -- presumably no opt variant exists; confirm */ +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/optimize.sh b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/optimize.sh new file mode 100755 index 0000000000..9d43dfa80d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/optimize.sh @@ -0,0 +1,121 @@ +#!/usr/bin/env sh +# Copyright (c) 2024 The mlkem-native project authors +# SPDX-License-Identifier: Apache-2.0 + +set -e + +TARGET_NAME="Cortex-A55" +TARGET=Arm_Cortex_A55 + +echo "* polyvec_basemul_acc_montgomery_cached, K=2, ${TARGET_NAME}" + +cp polyvec_clean.S polyvec_opt.S + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k2_clean,polyvec_basemul_acc_montgomery_cached_asm_k2_opt \ + -l k2_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c sw_pipelining.allow_post \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* polyvec_basemul_acc_montgomery_cached, K=3, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k3_clean,polyvec_basemul_acc_montgomery_cached_asm_k3_opt \ + -l k3_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c sw_pipelining.allow_post \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* polyvec_basemul_acc_montgomery_cached, K=4, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + polyvec_opt.S -o polyvec_opt.S \ + -r polyvec_basemul_acc_montgomery_cached_asm_k4_clean,polyvec_basemul_acc_montgomery_cached_asm_k4_opt \ + -l k4_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c 
sw_pipelining.allow_post \ + -c constraints.stalls_first_attempt=64 + +cp poly_clean.S poly_opt.S + +echo "* poly_reduce, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_reduce_asm_clean,poly_reduce_asm_opt \ + -l loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* poly_mulcache_compute, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_mulcache_compute_asm_clean,poly_mulcache_compute_asm_opt \ + -l mulcache_compute_loop_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo "* poly_tomont, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + poly_opt.S -o poly_opt.S \ + -r poly_tomont_asm_clean,poly_tomont_asm_opt \ + -l poly_tomont_asm_loop \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp,v8--v15]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo " * ntt, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + ntt_clean.S -o ntt_opt.S \ + -r ntt_asm_clean,ntt_asm_opt \ + -l layer123_start \ + -l layer4567_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 + +echo " * intt, ${TARGET_NAME}" + +slothy-cli Arm_AArch64 $TARGET \ + intt_clean.S -o intt_opt.S \ + -r intt_asm_clean,intt_asm_opt \ + -l layer123_start \ + -l layer4567_start \ + -c sw_pipelining.enabled=true \ + -c inputs_are_outputs \ + -c reserved_regs="[x18--x30,sp]" \ + -c 
sw_pipelining.minimize_overlapping=False \ + -c variable_size \ + -c constraints.stalls_first_attempt=64 diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_clean.S new file mode 100644 index 0000000000..f70a402215 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_clean.S @@ -0,0 +1,331 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27 +c_mont_constant: dup8h -1044 // 2^16 % 3329 +c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16) + +/* + * Some modular arithmetic macros + */ + +/* Barrett reduction */ +.macro barrett_reduce a + sqdmulh tmp.8h, \a\().8h, modulus_twisted.h[0] + srshr tmp.8h, tmp.8h, #11 + mls \a\().8h, tmp.8h, modulus.h[0] +.endm + +/* Montgomery multiplication, with precomputed Montgomery twist + * Expects modulus in modulus.h[0]; clobbers tmp0. */ +.macro mulmod dst, src, const, const_twisted + sqrdmulh tmp0.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, tmp0.8h, modulus.h[0] +.endm + +/* Turns signed-canonical to unsigned canonical representative + * through conditional addition of the modulus.
+ * + * Expected modulus in `modulus`. */ +.macro scalar_signed_to_unsigned a + sshr mask.8h, \a\().8h, #15 + and mask.16b, modulus.16b, mask.16b + add \a\().8h, \a\().8h, mask.8h +.endm + +/********************************** + * poly_reduce() * + **********************************/ + +.global MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean) + + ptr .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + + tmp .req v1 + mask .req v2 + modulus .req v3 + q_modulus .req q3 + modulus_twisted .req v4 + q_modulus_twisted .req q4 + +MLKEM_ASM_NAMESPACE(poly_reduce_asm_clean): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #8 +loop_start: + ldr q_data, [ptr], #64 + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-64] + + ldr q_data, [ptr, #-48] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-48] + + ldr q_data, [ptr, #-32] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-32] + + ldr q_data, [ptr, #-16] + barrett_reduce data + scalar_signed_to_unsigned data + str q_data, [ptr, #-16] + + subs count, count, #1 + cbnz count, loop_start + + ret + + .unreq ptr + .unreq count + + .unreq data + .unreq q_data + + .unreq tmp + .unreq mask + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_mulcache_compute() * + ********************************************/ + +.global MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean) + + cache_ptr .req x0 + data_ptr .req x1 + zeta_ptr .req x2 + zeta_twisted_ptr .req x3 + count .req x4 + + data_odd .req v0 + zeta .req v1 + q_zeta .req q1 + zeta_twisted .req v2 + q_zeta_twisted .req q2 + + tmp0 .req v3 + q_tmp0 .req q3 + tmp1 .req v4 + q_tmp1 .req q4 + dst .req v5 + q_dst .req q5 + + modulus .req v6 + q_modulus .req q6 + modulus_twisted .req v7 + q_modulus_twisted .req q7 + +MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_clean): + ldr 
q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #16 +mulcache_compute_loop_start: + ldr q_tmp0, [data_ptr], #32 + ldr q_tmp1, [data_ptr, #-16] + ldr q_zeta, [zeta_ptr], #16 + ldr q_zeta_twisted, [zeta_twisted_ptr], #16 + + // The mulcache of a polynomial a + b*X in Fq[X]/(X^2-zeta) is b*zeta; + // Since tmp0 || tmp1 represents multiple such polynomials as + // (a0,b0,a1,b1,...), extract only the odd elements. + uzp2 data_odd.8h, tmp0.8h, tmp1.8h + mulmod dst, data_odd, zeta, zeta_twisted + + str q_dst, [cache_ptr], #16 + + subs count, count, #1 + cbnz count, mulcache_compute_loop_start + + ret + + .unreq cache_ptr + .unreq data_ptr + .unreq zeta_ptr + .unreq zeta_twisted_ptr + .unreq count + + .unreq data_odd + .unreq zeta + .unreq q_zeta + .unreq zeta_twisted + .unreq q_zeta_twisted + + .unreq tmp0 + .unreq q_tmp0 + .unreq tmp1 + .unreq q_tmp1 + .unreq dst + .unreq q_dst + + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_tobytes() * + ********************************************/ +.global MLKEM_ASM_NAMESPACE(poly_tobytes_asm_clean) + + data0 .req v0 + data1 .req v1 + out0 .req v2 + out1 .req v3 + out2 .req v4 + tmp .req v5 + + dst .req x0 + src .req x1 + count .req x2 + +MLKEM_ASM_NAMESPACE(poly_tobytes_asm_clean): + + mov count, #16 +poly_tobytes_asm_clean_asm_loop_start: + ld2 {data0.8h, data1.8h}, [src], #32 + + // r[3 * i + 0] = (t0 >> 0); + xtn out0.8b, data0.8h + + // r[3 * i + 1] = (t0 >> 8); + shrn out1.8b, data0.8h, #8 + xtn tmp.8b, data1.8h + // r[3 * i + 1] = (t0 >> 8) | (t1 << 4); + sli out1.8b, tmp.8b, #4 + + // r[3 * i + 2] = (t1 >> 4); + shrn out2.8b, data1.8h, #4 + + st3 {out0.8b, out1.8b, out2.8b}, [dst], #24 + + subs count, count, #1 + cbnz count, poly_tobytes_asm_clean_asm_loop_start + ret + + .unreq data0 + .unreq data1 + .unreq out0 + .unreq out1 + .unreq out2 + .unreq tmp + .unreq dst + .unreq src + .unreq count +
+/********************************** + * poly_tomont() * + **********************************/ +.global MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean) + + src .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + res .req v1 + q_res .req q1 + + factor .req v2 + q_factor .req q2 + factor_t .req v3 + q_factor_t .req q3 + modulus .req v4 + q_modulus .req q4 + modulus_twisted .req v5 + q_modulus_twisted .req q5 + + tmp0 .req v6 + +MLKEM_ASM_NAMESPACE(poly_tomont_asm_clean): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + ldr q_factor, c_mont_constant + ldr q_factor_t, c_barrett_twist + + mov count, #8 +poly_tomont_asm_loop: + + ldr q_data, [src], #64 + mulmod res, data, factor, factor_t + str q_res, [src, #-64] + + ldr q_data, [src, #-48] + mulmod res, data, factor, factor_t + str q_res, [src, #-48] + + ldr q_data, [src, #-32] + mulmod res, data, factor, factor_t + str q_res, [src, #-32] + + ldr q_data, [src, #-16] + mulmod res, data, factor, factor_t + str q_res, [src, #-16] + + sub count, count, #1 + cbnz count, poly_tomont_asm_loop + + ret + + .unreq src + .unreq count + + .unreq data + .unreq q_data + .unreq res + .unreq q_res + + .unreq factor + .unreq q_factor + .unreq factor_t + .unreq q_factor_t + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + + .unreq tmp0 + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_opt.S new file mode 100644 index 0000000000..e58ee77c46 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/poly_opt.S @@ -0,0 +1,690 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +/* We use a single literal pool for all functions in this file. 
+ * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK. + */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 20159 // Barrett twist of 1 wrt 2^27 +c_mont_constant: dup8h -1044 // 2^16 % 3329 +c_barrett_twist: dup8h -10276 // Barrett twist of -1044 (wrt 2^16) + +/* + * Some modular arithmetic macros + */ + +/* Barrett reduction */ +.macro barrett_reduce a + sqdmulh tmp.8h, \a\().8h, modulus_twisted.h[0] + srshr tmp.8h, tmp.8h, #11 + mls \a\().8h, tmp.8h, modulus.h[0] +.endm + +/* Montgomery multiplication, with precomputed Montgomery twist + * Expects modulus in modulus.h[0]; clobbers tmp0. */ +.macro mulmod dst, src, const, const_twisted + sqrdmulh tmp0.8h, \src\().8h, \const_twisted\().8h + mul \dst\().8h, \src\().8h, \const\().8h + mls \dst\().8h, tmp0.8h, modulus.h[0] +.endm + +/* Turns signed-canonical to unsigned canonical representative + * through conditional addition of the modulus. + * + * Expected modulus in `modulus`.
*/ +.macro scalar_signed_to_unsigned a + sshr mask.8h, \a\().8h, #15 + and mask.16b, modulus.16b, mask.16b + add \a\().8h, \a\().8h, mask.8h +.endm + +/********************************** + * poly_reduce() * + **********************************/ + +.global MLKEM_ASM_NAMESPACE(poly_reduce_asm_opt) + + ptr .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + + tmp .req v1 + mask .req v2 + modulus .req v3 + q_modulus .req q3 + modulus_twisted .req v4 + q_modulus_twisted .req q4 + +MLKEM_ASM_NAMESPACE(poly_reduce_asm_opt): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #8 + // Instructions: 15 + // Expected cycles: 22 + // Expected IPC: 0.68 + + // Cycle bound: 22.0 + // IPC bound: 0.68 + + // Wall time: 0.05s + // User time: 0.05s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q21, [x0, #32] // *............................. + ldr q23, [x0, #48] // ..*........................... + sqdmulh v7.8H, v21.8H, v4.H[0] // ....*......................... + sqdmulh v30.8H, v23.8H, v4.H[0] // ......*....................... + srshr v7.8H, v7.8H, #11 // ........*..................... + srshr v30.8H, v30.8H, #11 // ..........*................... + mls v21.8H, v7.8H, v3.H[0] // ...........*.................. + mls v23.8H, v30.8H, v3.H[0] // .............*................ + ldr q5, [x0, #16] // ..............*............... + sshr v7.8H, v21.8H, #15 // ................*............. + sshr v30.8H, v23.8H, #15 // .................*............ + and v7.16B, v3.16B, v7.16B // ..................*........... + add v21.8H, v21.8H, v7.8H // ...................*.......... + and v7.16B, v3.16B, v30.16B // ....................*......... + add v16.8H, v23.8H, v7.8H // .....................*........ + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q30, [x0, #32] // *.............................. 
+ // sqdmulh v22.8H, v30.8H, v4.H[0] // ....*.......................... + // ldr q2, [x0, #48] // ..*............................ + // srshr v19.8H, v22.8H, #11 // ........*...................... + // mls v30.8H, v19.8H, v3.H[0] // ...........*................... + // sqdmulh v25.8H, v2.8H, v4.H[0] // ......*........................ + // sshr v31.8H, v30.8H, #15 // ................*.............. + // srshr v25.8H, v25.8H, #11 // ..........*.................... + // and v18.16B, v3.16B, v31.16B // ..................*............ + // mls v2.8H, v25.8H, v3.H[0] // .............*................. + // add v21.8H, v30.8H, v18.8H // ...................*........... + // ldr q5, [x0, #16] // ..............*................ + // sshr v18.8H, v2.8H, #15 // .................*............. + // and v27.16B, v3.16B, v18.16B // ....................*.......... + // add v16.8H, v2.8H, v27.8H // .....................*......... + + sub count, count, #1 +1: + // Instructions: 32 + // Expected cycles: 36 + // Expected IPC: 0.89 + + // Cycle bound: 36.0 + // IPC bound: 0.89 + + // Wall time: 1.05s + // User time: 1.05s + + // -------- cycle (expected) ---------> + // 0 25 + // |------------------------|---------- + ldr q6, [x0], #64 // *................................... + ldr q30, [x0, #32] // ..e................................. + sqdmulh v31.8H, v6.8H, v4.H[0] // ....*............................... + sqdmulh v29.8H, v5.8H, v4.H[0] // .....*.............................. + sqdmulh v22.8H, v30.8H, v4.H[0] // ......e............................. + str q16, [x0, #-16] // .......*............................ + srshr v20.8H, v31.8H, #11 // ........*........................... + srshr v28.8H, v29.8H, #11 // .........*.......................... + str q21, [x0, #-32] // ..........*......................... + mls v6.8H, v20.8H, v3.H[0] // ...........*........................ + mls v5.8H, v28.8H, v3.H[0] // ............*....................... 
+ ldr q2, [x0, #48] // .............e...................... + sshr v31.8H, v6.8H, #15 // ...............*.................... + srshr v19.8H, v22.8H, #11 // ................e................... + and v22.16B, v3.16B, v31.16B // .................*.................. + add v0.8H, v6.8H, v22.8H // ..................*................. + mls v30.8H, v19.8H, v3.H[0] // ...................e................ + sshr v26.8H, v5.8H, #15 // ....................*............... + sqdmulh v25.8H, v2.8H, v4.H[0] // .....................e.............. + and v17.16B, v3.16B, v26.16B // ......................*............. + add v1.8H, v5.8H, v17.8H // .......................*............ + sshr v31.8H, v30.8H, #15 // ........................e........... + srshr v25.8H, v25.8H, #11 // .........................e.......... + str q1, [x0, #-48] // ..........................*......... + and v18.16B, v3.16B, v31.16B // ...........................e........ + mls v2.8H, v25.8H, v3.H[0] // ............................e....... + add v21.8H, v30.8H, v18.8H // .............................e...... + ldr q5, [x0, #16] // ..............................e..... + sshr v18.8H, v2.8H, #15 // ................................e... + str q0, [x0, #-64] // .................................*.. + and v27.16B, v3.16B, v18.16B // ..................................e. + add v16.8H, v2.8H, v27.8H // ...................................e + + // ------------------------ cycle (expected) -------------------------> + // 0 25 50 + // |------------------------|------------------------|----------------- + // ldr q0, [x0], #64 // ..................................*................................. + // sqdmulh v1.8h, v0.8h, v4.h[0] // ..~...............................'...*............................. + // srshr v1.8h, v1.8h, #11 // ......~...........................'.......*......................... + // mls v0.8h, v1.8h, v3.h[0] // .........~........................'..........*...................... 
+ // sshr v2.8h, v0.8h, #15 // .............~....................'..............*.................. + // and v2.16b, v3.16b, v2.16b // ...............~..................'................*................ + // add v0.8h, v0.8h, v2.8h // ................~.................'.................*............... + // str q0, [x0, #-64] // ...............................~..'................................* + // ldr q0, [x0, #-48] // ............................e.....'.............................~... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ...~..............................'....*............................ + // srshr v1.8h, v1.8h, #11 // .......~..........................'........*........................ + // mls v0.8h, v1.8h, v3.h[0] // ..........~.......................'...........*..................... + // sshr v2.8h, v0.8h, #15 // ..................~...............'...................*............. + // and v2.16b, v3.16b, v2.16b // ....................~.............'.....................*........... + // add v0.8h, v0.8h, v2.8h // .....................~............'......................*.......... + // str q0, [x0, #-48] // ........................~.........'.........................*....... + // ldr q0, [x0, #-32] // e.................................'.~............................... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ....e.............................'.....~........................... + // srshr v1.8h, v1.8h, #11 // ..............e...................'...............~................. + // mls v0.8h, v1.8h, v3.h[0] // .................e................'..................~.............. + // sshr v2.8h, v0.8h, #15 // ......................e...........'.......................~......... + // and v2.16b, v3.16b, v2.16b // .........................e........'..........................~...... + // add v0.8h, v0.8h, v2.8h // ...........................e......'............................~.... 
+ // str q0, [x0, #-32] // ........~.........................'.........*....................... + // ldr q0, [x0, #-16] // ...........e......................'............~.................... + // sqdmulh v1.8h, v0.8h, v4.h[0] // ...................e..............'....................~............ + // srshr v1.8h, v1.8h, #11 // .......................e..........'........................~........ + // mls v0.8h, v1.8h, v3.h[0] // ..........................e.......'...........................~..... + // sshr v2.8h, v0.8h, #15 // ..............................e...'...............................~. + // and v2.16b, v3.16b, v2.16b // ................................e.'................................. + // add v0.8h, v0.8h, v2.8h // .................................e'................................. + // str q0, [x0, #-16] // .....~............................'......*.......................... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 17 + // Expected cycles: 23 + // Expected IPC: 0.74 + + // Cycle bound: 23.0 + // IPC bound: 0.74 + + // Wall time: 0.05s + // User time: 0.05s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + sqdmulh v20.8H, v5.8H, v4.H[0] // *............................. + ldr q24, [x0], #64 // .*............................ + str q21, [x0, #-32] // ...*.......................... + srshr v20.8H, v20.8H, #11 // ....*......................... + sqdmulh v25.8H, v24.8H, v4.H[0] // .....*........................ + str q16, [x0, #-16] // ......*....................... + mls v5.8H, v20.8H, v3.H[0] // .......*...................... + srshr v20.8H, v25.8H, #11 // .........*.................... + sshr v2.8H, v5.8H, #15 // ...........*.................. + mls v24.8H, v20.8H, v3.H[0] // ............*................. + and v20.16B, v3.16B, v2.16B // .............*................ + add v31.8H, v5.8H, v20.8H // ..............*............... + sshr v20.8H, v24.8H, #15 // ................*............. 
+ str q31, [x0, #-48] // .................*............ + and v31.16B, v3.16B, v20.16B // ..................*........... + add v24.8H, v24.8H, v31.8H // ...................*.......... + str q24, [x0, #-64] // ......................*....... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q6, [x0], #64 // .*............................. + // sqdmulh v31.8H, v6.8H, v4.H[0] // .....*......................... + // sqdmulh v29.8H, v5.8H, v4.H[0] // *.............................. + // str q16, [x0, #-16] // ......*........................ + // srshr v20.8H, v31.8H, #11 // .........*..................... + // srshr v28.8H, v29.8H, #11 // ....*.......................... + // str q21, [x0, #-32] // ...*........................... + // mls v6.8H, v20.8H, v3.H[0] // ............*.................. + // mls v5.8H, v28.8H, v3.H[0] // .......*....................... + // sshr v31.8H, v6.8H, #15 // ................*.............. + // and v22.16B, v3.16B, v31.16B // ..................*............ + // add v0.8H, v6.8H, v22.8H // ...................*........... + // sshr v26.8H, v5.8H, #15 // ...........*................... + // and v17.16B, v3.16B, v26.16B // .............*................. + // add v1.8H, v5.8H, v17.8H // ..............*................ + // str q1, [x0, #-48] // .................*............. + // str q0, [x0, #-64] // ......................*........ 
+ + + ret + + .unreq ptr + .unreq count + + .unreq data + .unreq q_data + + .unreq tmp + .unreq mask + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_mulcache_compute() * + ********************************************/ + +.global MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_opt) + + cache_ptr .req x0 + data_ptr .req x1 + zeta_ptr .req x2 + zeta_twisted_ptr .req x3 + count .req x4 + + data_odd .req v0 + zeta .req v1 + q_zeta .req q1 + zeta_twisted .req v2 + q_zeta_twisted .req q2 + + tmp0 .req v3 + q_tmp0 .req q3 + tmp1 .req v4 + q_tmp1 .req q4 + dst .req v5 + q_dst .req q5 + + modulus .req v6 + q_modulus .req q6 + modulus_twisted .req v7 + q_modulus_twisted .req q7 + +MLKEM_ASM_NAMESPACE(poly_mulcache_compute_asm_opt): + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + mov count, #16 + // Instructions: 7 + // Expected cycles: 12 + // Expected IPC: 0.58 + + // Cycle bound: 12.0 + // IPC bound: 0.58 + + // Wall time: 0.01s + // User time: 0.01s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q1, [x1, #16] // *............................. + ldr q27, [x1], #32 // ..*........................... + ldr q23, [x2], #16 // ....*......................... + uzp2 v27.8H, v27.8H, v1.8H // ......*....................... + ldr q1, [x3], #16 // .......*...................... + mul v2.8H, v27.8H, v23.8H // .........*.................... + sqrdmulh v27.8H, v27.8H, v1.8H // ...........*.................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q29, [x1, #16] // *.............................. + // ldr q21, [x2], #16 // ....*.......................... + // ldr q27, [x1], #32 // ..*............................ + // ldr q7, [x3], #16 // .......*....................... + // uzp2 v28.8H, v27.8H, v29.8H // ......*........................ 
+ // mul v2.8H, v28.8H, v21.8H // .........*..................... + // sqrdmulh v27.8H, v28.8H, v7.8H // ...........*................... + + sub count, count, #1 +1: + // Instructions: 9 + // Expected cycles: 13 + // Expected IPC: 0.69 + + // Cycle bound: 13.0 + // IPC bound: 0.69 + + // Wall time: 0.09s + // User time: 0.09s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q29, [x1, #16] // e............................. + ldr q21, [x2], #16 // ..e........................... + mls v2.8H, v27.8H, v6.H[0] // ....*......................... + ldr q27, [x1], #32 // .....e........................ + ldr q7, [x3], #16 // .......e...................... + uzp2 v28.8H, v27.8H, v29.8H // .........e.................... + str q2, [x0], #16 // ..........*................... + mul v2.8H, v28.8H, v21.8H // ...........e.................. + sqrdmulh v27.8H, v28.8H, v7.8H // ............e................. + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q3, [x1], #32 // .....e.......'....~.......'.... + // ldr q4, [x1, #-16] // e............~............~.... + // ldr q1, [x2], #16 // ..e..........'.~..........'.~.. + // ldr q2, [x3], #16 // .......e.....'......~.....'.... + // uzp2 v0.8h, v3.8h, v4.8h // .........e...'........~...'.... + // sqrdmulh v3.8h, v0.8h, v2.8h // ............e'...........~'.... + // mul v5.8h, v0.8h, v1.8h // ...........e.'..........~.'.... + // mls v5.8h, v3.8h, v6.h[0] // ....~........'...*........'.... + // str q5, [x0], #16 // ..........~..'.........*..'.... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 2 + // Expected cycles: 5 + // Expected IPC: 0.40 + + // Cycle bound: 5.0 + // IPC bound: 0.40 + + // Wall time: 0.00s + // User time: 0.00s + + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v2.8H, v27.8H, v6.H[0] // *............................. + str q2, [x0], #16 // ....*......................... 
+ + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // mls v2.8H, v27.8H, v6.H[0] // *.............................. + // str q2, [x0], #16 // ....*.......................... + + + ret + + .unreq cache_ptr + .unreq data_ptr + .unreq zeta_ptr + .unreq zeta_twisted_ptr + .unreq count + + .unreq data_odd + .unreq zeta + .unreq q_zeta + .unreq zeta_twisted + .unreq q_zeta_twisted + + .unreq tmp0 + .unreq q_tmp0 + .unreq tmp1 + .unreq q_tmp1 + .unreq dst + .unreq q_dst + + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + +/******************************************** + * poly_tobytes() * + ********************************************/ +.global MLKEM_ASM_NAMESPACE(poly_tobytes_asm_opt) + + data0 .req v0 + data1 .req v1 + out0 .req v2 + out1 .req v3 + out2 .req v4 + tmp .req v5 + + dst .req x0 + src .req x1 + count .req x2 + +MLKEM_ASM_NAMESPACE(poly_tobytes_asm_opt): + + mov count, #16 +poly_tobytes_asm_opt_asm_loop_start: + ld2 {data0.8h, data1.8h}, [src], #32 + + // r[3 * i + 0] = (t0 >> 0); + xtn out0.8b, data0.8h + + // r[3 * i + 1] = (t0 >> 8); + shrn out1.8b, data0.8h, #8 + xtn tmp.8b, data1.8h + // r[3 * i + 1] = (t0 >> 8) | (t1 << 4); + sli out1.8b, tmp.8b, #4 + + // r[3 * i + 2] = (t1 >> 4); + shrn out2.8b, data1.8h, #4 + + st3 {out0.8b, out1.8b, out2.8b}, [dst], #24 + + subs count, count, #1 + cbnz count, poly_tobytes_asm_opt_asm_loop_start + ret + + .unreq data0 + .unreq data1 + .unreq out0 + .unreq out1 + .unreq out2 + .unreq tmp + .unreq dst + .unreq src + .unreq count + +/********************************** + * poly_tomont() * + **********************************/ +.global MLKEM_ASM_NAMESPACE(poly_tomont_asm_opt) + + src .req x0 + count .req x1 + + data .req v0 + q_data .req q0 + res .req v1 + q_res .req q1 + + factor .req v2 + q_factor .req q2 + factor_t .req v3 + q_factor_t .req q3 + modulus .req v4 + q_modulus .req q4 + modulus_twisted .req v5 + q_modulus_twisted .req q5 + + tmp0 
.req v6 + +MLKEM_ASM_NAMESPACE(poly_tomont_asm_opt): + + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + ldr q_factor, c_mont_constant + ldr q_factor_t, c_barrett_twist + + mov count, #8 + // Instructions: 5 + // Expected cycles: 7 + // Expected IPC: 0.71 + // + // Cycle bound: 7.0 + // IPC bound: 0.71 + // + // Wall time: 0.01s + // User time: 0.01s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + ldr q26, [x0, #48] // *............................. + ldr q23, [x0, #16] // ..*........................... + mul v17.8H, v26.8H, v2.8H // ....*......................... + sqrdmulh v7.8H, v26.8H, v3.8H // .....*........................ + ldr q27, [x0, #32] // ......*....................... + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // ldr q7, [x0, #48] // *.............................. + // ldr q23, [x0, #16] // ..*............................ + // mul v17.8H, v7.8H, v2.8H // ....*.......................... + // sqrdmulh v7.8H, v7.8H, v3.8H // .....*......................... + // ldr q27, [x0, #32] // ......*........................ + + sub count, count, #1 +1: + // Instructions: 20 + // Expected cycles: 24 + // Expected IPC: 0.83 + // + // Cycle bound: 24.0 + // IPC bound: 0.83 + // + // Wall time: 0.73s + // User time: 0.73s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v17.8H, v7.8H, v4.H[0] // *............................. + sqrdmulh v5.8H, v23.8H, v3.8H // .*............................ + ldr q7, [x0], #64 // ..*........................... + str q17, [x0, #-16] // ....*......................... + sqrdmulh v29.8H, v27.8H, v3.8H // .....*........................ + sqrdmulh v19.8H, v7.8H, v3.8H // ......*....................... + mul v25.8H, v23.8H, v2.8H // .......*...................... + mul v0.8H, v7.8H, v2.8H // ........*..................... + mul v26.8H, v27.8H, v2.8H // .........*.................... 
+ ldr q7, [x0, #48] // ..........e................... + mls v25.8H, v5.8H, v4.H[0] // ............*................. + ldr q23, [x0, #16] // .............e................ + mls v26.8H, v29.8H, v4.H[0] // ...............*.............. + mls v0.8H, v19.8H, v4.H[0] // ................*............. + str q25, [x0, #-48] // .................*............ + mul v17.8H, v7.8H, v2.8H // ..................e........... + sqrdmulh v7.8H, v7.8H, v3.8H // ...................e.......... + str q0, [x0, #-64] // ....................*......... + ldr q27, [x0, #32] // .....................e........ + str q26, [x0, #-32] // .......................*...... + + // --------- cycle (expected) ----------> + // 0 25 + // |------------------------|------------ + // ldr q0, [x0], #64 // ..............'.*..................... + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'.....*................. + // mul v1.8h, v0.8h, v2.8h // ..............'.......*............... + // mls v1.8h, v6.8h, v4.h[0] // ......~.......'...............*....... + // str q1, [x0, #-64] // ..........~...'...................*... + // ldr q0, [x0, #-48] // ...e..........'............~.......... + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'*...................... + // mul v1.8h, v0.8h, v2.8h // ..............'......*................ + // mls v1.8h, v6.8h, v4.h[0] // ..~...........'...........*........... + // str q1, [x0, #-48] // .......~......'................*...... + // ldr q0, [x0, #-32] // ...........e..'....................~.. + // sqrdmulh v6.8h, v0.8h, v3.8h // ..............'....*.................. + // mul v1.8h, v0.8h, v2.8h // ..............'........*.............. + // mls v1.8h, v6.8h, v4.h[0] // .....~........'..............*........ + // str q1, [x0, #-32] // .............~'......................* + // ldr q0, [x0, #-16] // e.............'.........~............. + // sqrdmulh v6.8h, v0.8h, v3.8h // .........e....'..................~.... 
+ // mul v1.8h, v0.8h, v2.8h // ........e.....'.................~..... + // mls v1.8h, v6.8h, v4.h[0] // ..............*....................... + // str q1, [x0, #-16] // ..............'...*................... + + sub count, count, 1 + cbnz count, 1b + // Instructions: 15 + // Expected cycles: 18 + // Expected IPC: 0.83 + // + // Cycle bound: 18.0 + // IPC bound: 0.83 + // + // Wall time: 0.07s + // User time: 0.07s + // + // ----- cycle (expected) ------> + // 0 25 + // |------------------------|---- + mls v17.8H, v7.8H, v4.H[0] // *............................. + sqrdmulh v7.8H, v23.8H, v3.8H // .*............................ + mul v26.8H, v23.8H, v2.8H // ..*........................... + sqrdmulh v25.8H, v27.8H, v3.8H // ...*.......................... + ldr q23, [x0], #64 // ....*......................... + mul v27.8H, v27.8H, v2.8H // ......*....................... + mls v26.8H, v7.8H, v4.H[0] // .......*...................... + sqrdmulh v7.8H, v23.8H, v3.8H // ........*..................... + mul v23.8H, v23.8H, v2.8H // .........*.................... + str q17, [x0, #-16] // ..........*................... + mls v27.8H, v25.8H, v4.H[0] // ...........*.................. + str q26, [x0, #-48] // ............*................. + mls v23.8H, v7.8H, v4.H[0] // .............*................ + str q27, [x0, #-32] // ...............*.............. + str q23, [x0, #-64] // .................*............ + + // ------ cycle (expected) ------> + // 0 25 + // |------------------------|----- + // mls v17.8H, v7.8H, v4.H[0] // *.............................. + // sqrdmulh v5.8H, v23.8H, v3.8H // .*............................. + // ldr q7, [x0], #64 // ....*.......................... + // str q17, [x0, #-16] // ..........*.................... + // sqrdmulh v29.8H, v27.8H, v3.8H // ...*........................... + // sqrdmulh v19.8H, v7.8H, v3.8H // ........*...................... + // mul v25.8H, v23.8H, v2.8H // ..*............................ 
+ // mul v0.8H, v7.8H, v2.8H // .........*..................... + // mul v26.8H, v27.8H, v2.8H // ......*........................ + // mls v25.8H, v5.8H, v4.H[0] // .......*....................... + // mls v26.8H, v29.8H, v4.H[0] // ...........*................... + // mls v0.8H, v19.8H, v4.H[0] // .............*................. + // str q25, [x0, #-48] // ............*.................. + // str q0, [x0, #-64] // .................*............. + // str q26, [x0, #-32] // ...............*............... + + + ret + + .unreq src + .unreq count + + .unreq data + .unreq q_data + .unreq res + .unreq q_res + + .unreq factor + .unreq q_factor + .unreq factor_t + .unreq q_factor_t + .unreq modulus + .unreq q_modulus + .unreq modulus_twisted + .unreq q_modulus_twisted + + .unreq tmp0 + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_clean.S new file mode 100644 index 0000000000..99fb05de5d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_clean.S @@ -0,0 +1,288 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +// +// AArch64 re-implementation of the asymmetric base multiplication from: +// +// Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1 +// https://eprint.iacr.org/2021/986 +// https://github.com/neon-ntt/neon-ntt + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK.
+ */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 3327 // -3329^-1 mod 2^16 (3329 * 3327 = 169 * 2^16 - 1) + +// Input: +// - Vectors al, ah of 32-bit entries +// Output: +// - Montgomery reductions of al || ah, stored in x as 16-bit lanes (al, ah and t0 are clobbered) +.macro montgomery_reduce_long x, a + uzp1 t0.8h, \a\()l.8h, \a\()h.8h + mul t0.8h, t0.8h, modulus_twisted.8h + smlal \a\()l.4s, t0.4h, modulus.4h + smlal2 \a\()h.4s, t0.8h, modulus.8h + uzp2 \x\().8h, \a\()l.8h, \a\()h.8h +.endm + +// Computes products (a0*b0 + a1*b1t, a0*b1 + a1*b0) in 32-bit, overwriting d (pmlal accumulates onto d instead). +// +// Bounds: +// - Assume |a| < 4096, +// - Result: < 2*4096*2^15 = 2^28 +.macro pmull d, a, b + smull \d\()0l.4s, \a\()0.4h, \b\()0.4h + smull2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smull \d\()1l.4s, \a\()0.4h, \b\()1.4h + smull2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro pmlal d, a, b + smlal \d\()0l.4s, \a\()0.4h, \b\()0.4h + smlal2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smlal \d\()1l.4s, \a\()0.4h, \b\()1.4h + smlal2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro ld2_wrap a, ptr + ldr q_tmp0, [\ptr\()], #32 + ldr q_tmp1, [\ptr\(), #-16] + uzp1 \a\()0.8h, tmp0.8h, tmp1.8h + uzp2 \a\()1.8h, tmp0.8h, tmp1.8h +.endm + +.macro st2_wrap a, ptr + zip1 tmp0.8h, \a\()0.8h, \a\()1.8h + zip2 tmp1.8h, \a\()0.8h, \a\()1.8h + str q_tmp0, [\ptr\()], #32 + str q_tmp1, [\ptr\(), #-16] +.endm + +.macro load_polys a, b, a_ptr, b_ptr, b_cache_ptr + ld2_wrap \a\(), \a_ptr + ld2_wrap \b\(), \b_ptr + ld1 {\b\()1t.8h}, [\b_cache_ptr], #16 +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11,
[sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + out .req x0 + a0_ptr .req x1 + b0_ptr .req x2 + b0_cache_ptr .req x3 + a1_ptr .req x4 + b1_ptr .req x5 + b1_cache_ptr .req x6 + a2_ptr .req x7 + b2_ptr .req x8 + b2_cache_ptr .req x9 + a3_ptr .req x10 + b3_ptr .req x11 + b3_cache_ptr .req x12 + count .req x13 + + modulus .req v0 + q_modulus .req q0 + modulus_twisted .req v2 + q_modulus_twisted .req q2 + + aa0 .req v3 + aa1 .req v4 + bb0 .req v5 + bb1 .req v6 + bb1t .req v7 + + res0l .req v8 + res1l .req v9 + res0h .req v10 + res1h .req v11 + + tmp0 .req v12 + tmp1 .req v13 + q_tmp0 .req q12 + q_tmp1 .req q13 + + out0 .req v26 + out1 .req v27 + + t0 .req v28 + +#if MLKEM_K == 2 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Compute base pointers of vector entry 1: a/b entries are 512 bytes apart, b_cache entries 512/2 bytes apart + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + + mov count, #(MLKEM_N / 16) +k2_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k2_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 2 */ + +#if MLKEM_K == 3 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + //
Compute base pointers of vector entries 1..2: a/b entries are 512 bytes apart, b_cache entries 512/2 bytes apart + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + + mov count, #(MLKEM_N / 16) +k3_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr, b2_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k3_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 3 */ + +#if MLKEM_K == 4 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_clean): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Compute base pointers of vector entries 1..3: a/b entries are 512 bytes apart, b_cache entries 512/2 bytes apart + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + add a3_ptr, a0_ptr, #(3 * 512) + add b3_ptr, b0_ptr, #(3 * 512) + add b3_cache_ptr, b0_cache_ptr, #(3 * 512/2) + + // Bounds: + // + // Each pmull/pmlal term is bound by 2*4096*2^15=2^28, so the sum of the four + // terms before Montgomery reduction is bound by 2^30.
+ + mov count, #(MLKEM_N / 16) +k4_loop_start: + + load_polys aa, bb, a0_ptr, b0_ptr, b0_cache_ptr + pmull res, aa, bb + load_polys aa, bb, a1_ptr, b1_ptr, b1_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a2_ptr, b2_ptr, b2_cache_ptr + pmlal res, aa, bb + load_polys aa, bb, a3_ptr, b3_ptr, b3_cache_ptr + pmlal res, aa, bb + + montgomery_reduce_long out0, res0 + montgomery_reduce_long out1, res1 + + st2_wrap out, out + + subs count, count, #1 + cbnz count, k4_loop_start + + pop_stack + ret +#endif /* MLKEM_K == 4 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_opt.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_opt.S new file mode 100644 index 0000000000..16ed77c3fc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/polyvec_opt.S @@ -0,0 +1,1584 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// AArch64 re-implementation of the asymmetric base multiplication from: + +// Neon NTT: Faster Dilithium, Kyber, and Saber on Cortex-A72 and Apple M1 +// https://eprint.iacr.org/2021/986 +// https://github.com/neon-ntt/neon-ntt + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) + +/* We use a single literal pool for all functions in this file. + * This is OK even when the file gets expanded through SLOTHY, + * since PC-relative offsets are up to 1MB in AArch64. + * + * The use of dup8h to build constant vectors in memory + * is slightly wasteful and could be avoided with a GPR-load + * followed by Neon `dup`, but we're ultimately only talking + * about 64 bytes, so it seems OK.
+ */ + +.macro dup8h c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c + .short \c +.endm + +.p2align 4 +c_modulus: dup8h 3329 // ML-KEM modulus +c_modulus_twisted: dup8h 3327 // -3329^-1 mod 2^16 (3329 * 3327 = 169 * 2^16 - 1) + +// Input: +// - Vectors al, ah of 32-bit entries +// Output: +// - Montgomery reductions of al || ah, stored in x as 16-bit lanes (al, ah and t0 are clobbered) +.macro montgomery_reduce_long x, a + uzp1 t0.8h, \a\()l.8h, \a\()h.8h + mul t0.8h, t0.8h, modulus_twisted.8h + smlal \a\()l.4s, t0.4h, modulus.4h + smlal2 \a\()h.4s, t0.8h, modulus.8h + uzp2 \x\().8h, \a\()l.8h, \a\()h.8h +.endm + +// Computes products (a0*b0 + a1*b1t, a0*b1 + a1*b0) in 32-bit, overwriting d (pmlal accumulates onto d instead). + +// Bounds: +// - Assume |a| < 4096, +// - Result: < 2*4096*2^15 = 2^28 +.macro pmull d, a, b + smull \d\()0l.4s, \a\()0.4h, \b\()0.4h + smull2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smull \d\()1l.4s, \a\()0.4h, \b\()1.4h + smull2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro pmlal d, a, b + smlal \d\()0l.4s, \a\()0.4h, \b\()0.4h + smlal2 \d\()0h.4s, \a\()0.8h, \b\()0.8h + smlal \d\()0l.4s, \a\()1.4h, \b\()1t.4h + smlal2 \d\()0h.4s, \a\()1.8h, \b\()1t.8h + + smlal \d\()1l.4s, \a\()0.4h, \b\()1.4h + smlal2 \d\()1h.4s, \a\()0.8h, \b\()1.8h + smlal \d\()1l.4s, \a\()1.4h, \b\()0.4h + smlal2 \d\()1h.4s, \a\()1.8h, \b\()0.8h +.endm + +.macro ld2_wrap a, ptr + ldr q_tmp0, [\ptr\()], #32 + ldr q_tmp1, [\ptr\(), #-16] + uzp1 \a\()0.8h, tmp0.8h, tmp1.8h + uzp2 \a\()1.8h, tmp0.8h, tmp1.8h +.endm + +.macro st2_wrap a, ptr + zip1 tmp0.8h, \a\()0.8h, \a\()1.8h + zip2 tmp1.8h, \a\()0.8h, \a\()1.8h + str q_tmp0, [\ptr\()], #32 + str q_tmp1, [\ptr\(), #-16] +.endm + +.macro load_polys a, b, a_ptr, b_ptr, b_cache_ptr + ld2_wrap \a\(), \a_ptr + ld2_wrap \b\(), \b_ptr + ld1 {\b\()1t.8h}, [\b_cache_ptr], #16 +.endm + +.macro save_vregs + sub sp, sp, #(16*4) + stp d8, d9, [sp, #16*0] + stp d10, d11,
[sp, #16*1] + stp d12, d13, [sp, #16*2] + stp d14, d15, [sp, #16*3] +.endm + +.macro restore_vregs + ldp d8, d9, [sp, #16*0] + ldp d10, d11, [sp, #16*1] + ldp d12, d13, [sp, #16*2] + ldp d14, d15, [sp, #16*3] + add sp, sp, #(16*4) +.endm + +.macro push_stack + save_vregs +.endm + +.macro pop_stack + restore_vregs +.endm + + out .req x0 + a0_ptr .req x1 + b0_ptr .req x2 + b0_cache_ptr .req x3 + a1_ptr .req x4 + b1_ptr .req x5 + b1_cache_ptr .req x6 + a2_ptr .req x7 + b2_ptr .req x8 + b2_cache_ptr .req x9 + a3_ptr .req x10 + b3_ptr .req x11 + b3_cache_ptr .req x12 + count .req x13 + + modulus .req v0 + q_modulus .req q0 + modulus_twisted .req v2 + q_modulus_twisted .req q2 + + aa0 .req v3 + aa1 .req v4 + bb0 .req v5 + bb1 .req v6 + bb1t .req v7 + + res0l .req v8 + res1l .req v9 + res0h .req v10 + res1h .req v11 + + tmp0 .req v12 + tmp1 .req v13 + q_tmp0 .req q12 + q_tmp1 .req q13 + + out0 .req v26 + out1 .req v27 + + t0 .req v28 + +#if MLKEM_K == 2 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + + mov count, #(MLKEM_N / 16) + // Instructions: 75 + // Expected cycles: 94 + // Expected IPC: 0.80 + + // Cycle bound: 94.0 + // IPC bound: 0.80 + + // Wall time: 1.49s + // User time: 1.49s + + // --------------------------- original position ----------------------------> + // 0 25 50 + // |------------------------|------------------------| + ldr q9, [x4], #32 // *.......................................................................... + ldr q5, [x4, #-16] // ......*.................................................................... + ldr q11, [x5], #32 // .*......................................................................... 
+ uzp1 v23.8H, v9.8H, v5.8H // .........*................................................................. + uzp2 v9.8H, v9.8H, v5.8H // .....................*..................................................... + ldr q5, [x2], #32 // ..*........................................................................ + ldr q7, [x5, #-16] // ..............*............................................................ + ldr q21, [x2, #-16] // ...*....................................................................... + uzp2 v10.8H, v11.8H, v7.8H // .................*......................................................... + uzp1 v11.8H, v11.8H, v7.8H // ..................*........................................................ + uzp1 v7.8H, v5.8H, v21.8H // ....*...................................................................... + uzp2 v5.8H, v5.8H, v21.8H // .....*..................................................................... + ldr q21, [x1], #32 // .......*................................................................... + ldr q25, [x1, #-16] // ........*.................................................................. + ld1 {v6.8H}, [x3], #16 // ............................*.............................................. + uzp1 v26.8H, v21.8H, v25.8H // ..........*................................................................ + uzp2 v21.8H, v21.8H, v25.8H // ...........*............................................................... + smull v25.4S, v26.4H, v5.4H // ............*.............................................................. + smull2 v5.4S, v26.8H, v5.8H // .............*............................................................. + smull v19.4S, v26.4H, v7.4H // ..........................*................................................ + smull2 v26.4S, v26.8H, v7.8H // ..............................*............................................ 
+ smlal v25.4S, v21.4H, v7.4H // ...............*........................................................... + smlal2 v5.4S, v21.8H, v7.8H // ................*.......................................................... + smlal v19.4S, v21.4H, v6.4H // ...................................*....................................... + smlal2 v26.4S, v21.8H, v6.8H // .................................*......................................... + smlal v25.4S, v23.4H, v10.4H // ...................*....................................................... + smlal2 v5.4S, v23.8H, v10.8H // ....................*...................................................... + smlal v19.4S, v23.4H, v11.4H // ......................................*.................................... + smlal2 v26.4S, v23.8H, v11.8H // ....................................*...................................... + ld1 {v23.8H}, [x6], #16 // ........................*.................................................. + smlal v25.4S, v9.4H, v11.4H // ......................*.................................................... + smlal2 v5.4S, v9.8H, v11.8H // .......................*................................................... + smlal2 v26.4S, v9.8H, v23.8H // .......................................*................................... + smlal v19.4S, v9.4H, v23.4H // .........................................*................................. + ldr q9, [x4], #32 // ...............................*........................................... + uzp1 v11.8H, v25.8H, v5.8H // .........................*................................................. + uzp1 v23.8H, v19.8H, v26.8H // .............................................*............................. + mul v11.8H, v11.8H, v2.8H // ...........................*............................................... + mul v23.8H, v23.8H, v2.8H // ..............................................*............................ 
+ ldr q7, [x5], #32 // ................................*.......................................... + smlal2 v5.4S, v11.8H, v0.8H // .............................*............................................. + smlal v25.4S, v11.4H, v0.4H // ..................................*........................................ + ldr q11, [x2], #32 // .....................................*..................................... + ldr q21, [x2, #-16] // ........................................*.................................. + ldr q6, [x4, #-16] // ...............................................*........................... + uzp1 v17.8H, v11.8H, v21.8H // ...........................................*............................... + ldr q10, [x1], #32 // ................................................*.......................... + ldr q29, [x1, #-16] // .................................................*......................... + uzp2 v11.8H, v11.8H, v21.8H // ............................................*.............................. + uzp1 v13.8H, v9.8H, v6.8H // ...................................................*....................... + uzp1 v3.8H, v10.8H, v29.8H // ....................................................*...................... + uzp2 v10.8H, v10.8H, v29.8H // .....................................................*..................... + smull v12.4S, v3.4H, v11.4H // ......................................................*.................... + smull2 v11.4S, v3.8H, v11.8H // .......................................................*................... + ldr q21, [x5, #-16] // ........................................................*.................. + smlal v12.4S, v10.4H, v17.4H // .........................................................*................. + smlal2 v11.4S, v10.8H, v17.8H // ..........................................................*................ 
+ uzp2 v29.8H, v7.8H, v21.8H // ...........................................................*............... + uzp1 v15.8H, v7.8H, v21.8H // ............................................................*.............. + smlal v12.4S, v13.4H, v29.4H // .............................................................*............. + smlal2 v11.4S, v13.8H, v29.8H // ..............................................................*............ + uzp2 v28.8H, v9.8H, v6.8H // ...............................................................*........... + smlal2 v26.4S, v23.8H, v0.8H // ..................................................*........................ + smlal v12.4S, v28.4H, v15.4H // .................................................................*......... + smlal2 v11.4S, v28.8H, v15.8H // ..................................................................*........ + smlal v19.4S, v23.4H, v0.4H // ................................................................*.......... + uzp2 v27.8H, v25.8H, v5.8H // ..........................................*................................ + smull v23.4S, v3.4H, v17.4H // ......................................................................*.... + uzp1 v9.8H, v12.8H, v11.8H // .....................................................................*..... + uzp2 v19.8H, v19.8H, v26.8H // ....................................................................*...... + mul v14.8H, v9.8H, v2.8H // .......................................................................*... + ld1 {v22.8H}, [x6], #16 // ...................................................................*....... + zip2 v9.8H, v19.8H, v27.8H // ........................................................................*.. + smlal2 v11.4S, v14.8H, v0.8H // ..........................................................................* + ld1 {v4.8H}, [x3], #16 // .........................................................................*. 
+ + // ------------------------------ new position ------------------------------> + // 0 25 50 + // |------------------------|------------------------|------------------------ + // ldr q18, [x4], #32 // *.......................................................................... + // ldr q30, [x5], #32 // ..*........................................................................ + // ldr q8, [x2], #32 // .....*..................................................................... + // ldr q9, [x2, #-16] // .......*................................................................... + // uzp1 v17.8H, v8.8H, v9.8H // ..........*................................................................ + // uzp2 v4.8H, v8.8H, v9.8H // ...........*............................................................... + // ldr q19, [x4, #-16] // .*......................................................................... + // ldr q29, [x1], #32 // ............*.............................................................. + // ldr q12, [x1, #-16] // .............*............................................................. + // uzp1 v13.8H, v18.8H, v19.8H // ...*....................................................................... + // uzp1 v3.8H, v29.8H, v12.8H // ...............*........................................................... + // uzp2 v10.8H, v29.8H, v12.8H // ................*.......................................................... + // smull v12.4S, v3.4H, v4.4H // .................*......................................................... + // smull2 v11.4S, v3.8H, v4.8H // ..................*........................................................ + // ldr q5, [x5, #-16] // ......*.................................................................... + // smlal v12.4S, v10.4H, v17.4H // .....................*..................................................... 
+ // smlal2 v11.4S, v10.8H, v17.8H // ......................*.................................................... + // uzp2 v14.8H, v30.8H, v5.8H // ........*.................................................................. + // uzp1 v15.8H, v30.8H, v5.8H // .........*................................................................. + // smlal v12.4S, v13.4H, v14.4H // .........................*................................................. + // smlal2 v11.4S, v13.8H, v14.8H // ..........................*................................................ + // uzp2 v28.8H, v18.8H, v19.8H // ....*...................................................................... + // smlal v12.4S, v28.4H, v15.4H // ..............................*............................................ + // smlal2 v11.4S, v28.8H, v15.8H // ...............................*........................................... + // ld1 {v22.8H}, [x6], #16 // .............................*............................................. + // uzp1 v1.8H, v12.8H, v11.8H // ...................................*....................................... + // smull v23.4S, v3.4H, v17.4H // ...................*....................................................... + // mul v14.8H, v1.8H, v2.8H // .....................................*..................................... + // ld1 {v4.8H}, [x3], #16 // ..............*............................................................ + // smlal2 v11.4S, v14.8H, v0.8H // ........................................*.................................. + // smull2 v20.4S, v3.8H, v17.8H // ....................*...................................................... + // ldr q18, [x4], #32 // ..................................*........................................ + // ldr q30, [x5], #32 // .......................................*................................... + // smlal2 v20.4S, v10.8H, v4.8H // ........................*.................................................. 
+ // smlal v12.4S, v14.4H, v0.4H // .........................................*................................. + // smlal v23.4S, v10.4H, v4.4H // .......................*................................................... + // smlal2 v20.4S, v13.8H, v15.8H // ............................*.............................................. + // ldr q8, [x2], #32 // ..........................................*................................ + // smlal v23.4S, v13.4H, v15.4H // ...........................*............................................... + // smlal2 v20.4S, v28.8H, v22.8H // ................................*.......................................... + // ldr q9, [x2, #-16] // ...........................................*............................... + // smlal v23.4S, v28.4H, v22.4H // .................................*......................................... + // uzp2 v27.8H, v12.8H, v11.8H // ..................................................................*........ + // uzp1 v17.8H, v8.8H, v9.8H // .............................................*............................. + // uzp2 v4.8H, v8.8H, v9.8H // ................................................*.......................... + // uzp1 v5.8H, v23.8H, v20.8H // ....................................*...................................... + // mul v31.8H, v5.8H, v2.8H // ......................................*.................................... + // ldr q19, [x4, #-16] // ............................................*.............................. + // ldr q29, [x1], #32 // ..............................................*............................ + // ldr q12, [x1, #-16] // ...............................................*........................... + // smlal2 v20.4S, v31.8H, v0.8H // ..............................................................*............ + // uzp1 v13.8H, v18.8H, v19.8H // .................................................*......................... 
+ // uzp1 v3.8H, v29.8H, v12.8H // ..................................................*........................ + // uzp2 v10.8H, v29.8H, v12.8H // ...................................................*....................... + // smull v12.4S, v3.4H, v4.4H // ....................................................*...................... + // smull2 v11.4S, v3.8H, v4.8H // .....................................................*..................... + // ldr q5, [x5, #-16] // ......................................................*.................... + // smlal v12.4S, v10.4H, v17.4H // .......................................................*................... + // smlal2 v11.4S, v10.8H, v17.8H // ........................................................*.................. + // uzp2 v14.8H, v30.8H, v5.8H // .........................................................*................. + // uzp1 v15.8H, v30.8H, v5.8H // ..........................................................*................ + // smlal v12.4S, v13.4H, v14.4H // ...........................................................*............... + // smlal2 v11.4S, v13.8H, v14.8H // ............................................................*.............. + // uzp2 v28.8H, v18.8H, v19.8H // .............................................................*............. + // smlal v23.4S, v31.4H, v0.4H // .................................................................*......... + // smlal v12.4S, v28.4H, v15.4H // ...............................................................*........... + // smlal2 v11.4S, v28.8H, v15.8H // ................................................................*.......... + // ld1 {v22.8H}, [x6], #16 // .......................................................................*... + // uzp2 v19.8H, v23.8H, v20.8H // .....................................................................*..... 
+ // uzp1 v1.8H, v12.8H, v11.8H // ....................................................................*...... + // smull v23.4S, v3.4H, v17.4H // ...................................................................*....... + // mul v14.8H, v1.8H, v2.8H // ......................................................................*.... + // zip2 v9.8H, v19.8H, v27.8H // ........................................................................*.. + // ld1 {v4.8H}, [x3], #16 // ..........................................................................* + // smlal2 v11.4S, v14.8H, v0.8H // .........................................................................*. + + sub count, count, #2 +1: + // Instructions: 48 + // Expected cycles: 58 + // Expected IPC: 0.83 + + // Cycle bound: 58.0 + // IPC bound: 0.83 + + // Wall time: 6.39s + // User time: 6.39s + + // -------------- original position --------------> + // 0 25 + // |------------------------|---------------------- + smull2 v20.4S, v3.8H, v17.8H // ..........*..................................... + ldr q18, [x4], #32 // .................e.............................. + ldr q30, [x5], #32 // .....................e.......................... + smlal2 v20.4S, v10.8H, v4.8H // ............*................................... + smlal v12.4S, v14.4H, v0.4H // .........................................*...... + smlal v23.4S, v10.4H, v4.4H // ...........*.................................... + str q9, [x0, #16] // ...............................................l + smlal2 v20.4S, v13.8H, v15.8H // ...........................*.................... + ldr q8, [x2], #32 // ....e........................................... + smlal v23.4S, v13.4H, v15.4H // ..........................*..................... + smlal2 v20.4S, v28.8H, v22.8H // .............................*.................. + zip1 v26.8H, v19.8H, v27.8H // ............................................l... 
+ ldr q9, [x2, #-16] // .....e.......................................... + smlal v23.4S, v28.4H, v22.4H // ............................*................... + uzp2 v27.8H, v12.8H, v11.8H // ...........................................*.... + uzp1 v17.8H, v8.8H, v9.8H // ......e......................................... + uzp2 v4.8H, v8.8H, v9.8H // .......e........................................ + uzp1 v5.8H, v23.8H, v20.8H // ..................................*............. + str q26, [x0], #32 // ..............................................l. + mul v31.8H, v5.8H, v2.8H // ...................................*............ + ldr q19, [x4, #-16] // ..................e............................. + ldr q29, [x1], #32 // e............................................... + ldr q12, [x1, #-16] // .e.............................................. + smlal2 v20.4S, v31.8H, v0.8H // .....................................*.......... + uzp1 v13.8H, v18.8H, v19.8H // ...................e............................ + uzp1 v3.8H, v29.8H, v12.8H // ..e............................................. + uzp2 v10.8H, v29.8H, v12.8H // ...e............................................ + smull v12.4S, v3.4H, v4.4H // .............e.................................. + smull2 v11.4S, v3.8H, v4.8H // ..............e................................. + ldr q5, [x5, #-16] // ......................e......................... + smlal v12.4S, v10.4H, v17.4H // ...............e................................ + smlal2 v11.4S, v10.8H, v17.8H // ................e............................... + uzp2 v14.8H, v30.8H, v5.8H // ........................e....................... + uzp1 v15.8H, v30.8H, v5.8H // .......................e........................ + smlal v12.4S, v13.4H, v14.4H // ..............................e................. + smlal2 v11.4S, v13.8H, v14.8H // ...............................e................ + uzp2 v28.8H, v18.8H, v19.8H // ....................e........................... 
+ smlal v23.4S, v31.4H, v0.4H // ....................................*........... + smlal v12.4S, v28.4H, v15.4H // ................................e............... + smlal2 v11.4S, v28.8H, v15.8H // .................................e.............. + ld1 {v22.8H}, [x6], #16 // .........................e...................... + uzp2 v19.8H, v23.8H, v20.8H // ......................................*......... + uzp1 v1.8H, v12.8H, v11.8H // .......................................e........ + smull v23.4S, v3.4H, v17.4H // .........e...................................... + mul v14.8H, v1.8H, v2.8H // ........................................e....... + zip2 v9.8H, v19.8H, v27.8H // .............................................*.. + ld1 {v4.8H}, [x3], #16 // ........e....................................... + smlal2 v11.4S, v14.8H, v0.8H // ..........................................e..... + + // ------------------------------------------------- new position --------------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + // ldr q12, [x1], #32 // ....................e..........................'....................~..........................'.................. + // ldr q13, [x1, #-16] // .....................e.........................'.....................~.........................'.................. + // uzp1 v3.8h, v12.8h, v13.8h // ........................e......................'........................~......................'.................. + // uzp2 v4.8h, v12.8h, v13.8h // .........................e.....................'.........................~.....................'.................. + // ldr q12, [x2], #32 // .......e.......................................'.......~.......................................'.......~.......... 
+ // ldr q13, [x2, #-16] // ...........e...................................'...........~...................................'...........~...... + // uzp1 v5.8h, v12.8h, v13.8h // ..............e................................'..............~................................'..............~... + // uzp2 v6.8h, v12.8h, v13.8h // ...............e...............................'...............~...............................'...............~.. + // ld1 {v7.8h}, [x3], #16 // .............................................e.'.............................................~.'.................. + // smull v8.4s, v3.4h, v5.4h // ..........................................e....'..........................................~....'.................. + // smull2 v10.4s, v3.8h, v5.8h // ...............................................*...............................................~.................. + // smlal v8.4s, v4.4h, v7.4h // ....~..........................................'....*..........................................'....~............. + // smlal2 v10.4s, v4.8h, v7.8h // ..~............................................'..*............................................'..~............... + // smull v9.4s, v3.4h, v6.4h // ..........................e....................'..........................~....................'.................. + // smull2 v11.4s, v3.8h, v6.8h // ...........................e...................'...........................~...................'.................. + // smlal v9.4s, v4.4h, v5.4h // .............................e.................'.............................~.................'.................. + // smlal2 v11.4s, v4.8h, v5.8h // ..............................e................'..............................~................'.................. + // ldr q12, [x4], #32 // e..............................................'~..............................................'~................. 
+ // ldr q13, [x4, #-16] // ...................e...........................'...................~...........................'.................. + // uzp1 v3.8h, v12.8h, v13.8h // .......................e.......................'.......................~.......................'.................. + // uzp2 v4.8h, v12.8h, v13.8h // ...................................e...........'...................................~...........'.................. + // ldr q12, [x5], #32 // .e.............................................'.~.............................................'.~................ + // ldr q13, [x5, #-16] // ............................e..................'............................~..................'.................. + // uzp1 v5.8h, v12.8h, v13.8h // ................................e..............'................................~..............'.................. + // uzp2 v6.8h, v12.8h, v13.8h // ...............................e...............'...............................~...............'.................. + // ld1 {v7.8h}, [x6], #16 // .......................................e.......'.......................................~.......'.................. + // smlal v8.4s, v3.4h, v5.4h // ........~......................................'........*......................................'........~......... + // smlal2 v10.4s, v3.8h, v5.8h // ......~........................................'......*........................................'......~........... + // smlal v8.4s, v4.4h, v7.4h // ............~..................................'............*..................................'............~..... + // smlal2 v10.4s, v4.8h, v7.8h // .........~.....................................'.........*.....................................'.........~........ + // smlal v9.4s, v3.4h, v6.4h // .................................e.............'.................................~.............'.................. 
+ // smlal2 v11.4s, v3.8h, v6.8h // ..................................e............'..................................~............'.................. + // smlal v9.4s, v4.4h, v5.4h // .....................................e.........'.....................................~.........'.................. + // smlal2 v11.4s, v4.8h, v5.8h // ......................................e........'......................................~........'.................. + // uzp1 v28.8h, v8.8h, v10.8h // ................~..............................'................*..............................'................~. + // mul v28.8h, v28.8h, v2.8h // ..................~............................'..................*............................'.................. + // smlal v8.4s, v28.4h, v0.4h // ....................................~..........'....................................*..........'.................. + // smlal2 v10.4s, v28.8h, v0.8h // ......................~........................'......................*........................'.................. + // uzp2 v26.8h, v8.8h, v10.8h // ........................................~......'........................................*......'.................. + // uzp1 v28.8h, v9.8h, v11.8h // .........................................e.....'.........................................~.....'.................. + // mul v28.8h, v28.8h, v2.8h // ...........................................e...'...........................................~...'.................. + // smlal v9.4s, v28.4h, v0.4h // ...~...........................................'...*...........................................'...~.............. + // smlal2 v11.4s, v28.8h, v0.8h // ..............................................e'..............................................~'.................. + // uzp2 v27.8h, v9.8h, v11.8h // .............~.................................'.............*.................................'.............~.... 
+ // zip1 v12.8h, v26.8h, v27.8h // ..........~....................................'..........~....................................'..........l....... + // zip2 v13.8h, v26.8h, v27.8h // ............................................~..'............................................*..'.................. + // str q12, [x0], #32 // .................~.............................'.................~.............................'.................l + // str q13, [x0, #-16] // .....~.........................................'.....~.........................................'.....l............ + + sub count, count, #1 + cbnz count, 1b + // Instructions: 21 + // Expected cycles: 35 + // Expected IPC: 0.60 + + // Cycle bound: 35.0 + // IPC bound: 0.60 + + // Wall time: 0.08s + // User time: 0.08s + + // ----- original position -----> + // 0 25 + // |------------------------|---- + smull2 v5.4S, v3.8H, v17.8H // *............................. + smlal v12.4S, v14.4H, v0.4H // ..*........................... + smlal v23.4S, v10.4H, v4.4H // ...*.......................... + str q9, [x0, #16] // ....*......................... + smlal2 v5.4S, v10.8H, v4.8H // .*............................ + uzp2 v11.8H, v12.8H, v11.8H // ..........*................... + zip1 v9.8H, v19.8H, v27.8H // ........*..................... + smlal v23.4S, v13.4H, v15.4H // ......*....................... + smlal2 v5.4S, v13.8H, v15.8H // .....*........................ + str q9, [x0], #32 // ............*................. + smlal v23.4S, v28.4H, v22.4H // .........*.................... + smlal2 v5.4S, v28.8H, v22.8H // .......*...................... + uzp1 v9.8H, v23.8H, v5.8H // ...........*.................. + mul v9.8H, v9.8H, v2.8H // .............*................ + smlal2 v5.4S, v9.8H, v0.8H // ..............*............... + smlal v23.4S, v9.4H, v0.4H // ...............*.............. + uzp2 v9.8H, v23.8H, v5.8H // ................*............. + zip2 v5.8H, v9.8H, v11.8H // .................*............ 
+ zip1 v9.8H, v9.8H, v11.8H // ...................*.......... + str q5, [x0, #16] // ..................*........... + str q9, [x0], #32 // ....................*......... + + // -------- new position --------> + // 0 25 + // |------------------------|----- + // smull2 v20.4S, v3.8H, v17.8H // *.............................. + // smlal2 v20.4S, v10.8H, v4.8H // ....*.......................... + // smlal v12.4S, v14.4H, v0.4H // .*............................. + // smlal v23.4S, v10.4H, v4.4H // ..*............................ + // str q9, [x0, #16] // ...*........................... + // smlal2 v20.4S, v13.8H, v15.8H // ........*...................... + // smlal v23.4S, v13.4H, v15.4H // .......*....................... + // smlal2 v20.4S, v28.8H, v22.8H // ...........*................... + // zip1 v26.8H, v19.8H, v27.8H // ......*........................ + // smlal v23.4S, v28.4H, v22.4H // ..........*.................... + // uzp2 v27.8H, v12.8H, v11.8H // .....*......................... + // uzp1 v5.8H, v23.8H, v20.8H // ............*.................. + // str q26, [x0], #32 // .........*..................... + // mul v31.8H, v5.8H, v2.8H // .............*................. + // smlal2 v20.4S, v31.8H, v0.8H // ..............*................ + // smlal v23.4S, v31.4H, v0.4H // ...............*............... + // uzp2 v19.8H, v23.8H, v20.8H // ................*.............. + // zip2 v9.8H, v19.8H, v27.8H // .................*............. + // str q9, [x0, #16] // ...................*........... + // zip1 v26.8H, v19.8H, v27.8H // ..................*............ + // str q26, [x0], #32 // ....................*.......... 
+ + + pop_stack + ret +#endif /* MLKEM_K == 2 */ + +#if MLKEM_K == 3 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + + mov count, #(MLKEM_N / 16) + // Instructions: 75 + // Expected cycles: 103 + // Expected IPC: 0.73 + + // Cycle bound: 103.0 + // IPC bound: 0.73 + + // Wall time: 0.94s + // User time: 0.94s + + // --------------------------- original position ----------------------------> + // 0 25 50 + // |------------------------|------------------------| + ldr q7, [x2, #16] // *.......................................................................... + ldr q20, [x2], #32 // ..*........................................................................ + ldr q15, [x1, #16] // .*......................................................................... + uzp1 v8.8H, v20.8H, v7.8H // ...............*........................................................... + uzp2 v7.8H, v20.8H, v7.8H // ................*.......................................................... + ld1 {v20.8H}, [x3], #16 // ...*....................................................................... + ldr q30, [x1], #32 // ..............*............................................................ + ldr q11, [x4], #32 // ....*...................................................................... + uzp1 v16.8H, v30.8H, v15.8H // .................*......................................................... + uzp2 v15.8H, v30.8H, v15.8H // ..................*........................................................ 
+ smull v30.4S, v16.4H, v7.4H // ...................*....................................................... + smull2 v7.4S, v16.8H, v7.8H // ....................*...................................................... + smull v9.4S, v16.4H, v8.4H // .....................*..................................................... + smull2 v16.4S, v16.8H, v8.8H // ......................*.................................................... + smlal v30.4S, v15.4H, v8.4H // .......................*................................................... + smlal2 v7.4S, v15.8H, v8.8H // ........................*.................................................. + smlal v9.4S, v15.4H, v20.4H // .........................*................................................. + smlal2 v16.4S, v15.8H, v20.8H // ..........................*................................................ + ldr q20, [x4, #-16] // .....*..................................................................... + ldr q15, [x5], #32 // ......*.................................................................... + uzp1 v8.8H, v11.8H, v20.8H // ...........................*............................................... + uzp2 v20.8H, v11.8H, v20.8H // ............................*.............................................. + ldr q11, [x5, #-16] // .......*................................................................... + ld1 {v27.8H}, [x6], #16 // ........*.................................................................. + uzp1 v10.8H, v15.8H, v11.8H // .............................*............................................. + uzp2 v15.8H, v15.8H, v11.8H // ..............................*............................................ + smlal v9.4S, v8.4H, v10.4H // ...............................*........................................... + smlal2 v16.4S, v8.8H, v10.8H // ................................*.......................................... 
+ smlal v30.4S, v8.4H, v15.4H // .................................*......................................... + smlal2 v7.4S, v8.8H, v15.8H // ..................................*........................................ + smlal v9.4S, v20.4H, v27.4H // ...................................*....................................... + smlal2 v16.4S, v20.8H, v27.8H // ....................................*...................................... + smlal v30.4S, v20.4H, v10.4H // .....................................*..................................... + smlal2 v7.4S, v20.8H, v10.8H // ......................................*.................................... + ldr q20, [x7], #32 // .........*................................................................. + ldr q15, [x7, #-16] // ..........*................................................................ + ldr q8, [x8], #32 // ...........*............................................................... + uzp1 v11.8H, v20.8H, v15.8H // .......................................*................................... + uzp2 v20.8H, v20.8H, v15.8H // ........................................*.................................. + ldr q15, [x8, #-16] // ............*.............................................................. + ld1 {v27.8H}, [x9], #16 // .............*............................................................. + uzp1 v10.8H, v8.8H, v15.8H // .........................................*................................. + uzp2 v15.8H, v8.8H, v15.8H // ..........................................*................................ + smlal v9.4S, v11.4H, v10.4H // ...........................................*............................... + smlal2 v16.4S, v11.8H, v10.8H // ............................................*.............................. + smlal v30.4S, v11.4H, v15.4H // .............................................*............................. 
+ smlal2 v7.4S, v11.8H, v15.8H // ..............................................*............................ + smlal v9.4S, v20.4H, v27.4H // ...............................................*........................... + smlal2 v16.4S, v20.8H, v27.8H // ................................................*.......................... + smlal v30.4S, v20.4H, v10.4H // .................................................*......................... + smlal2 v7.4S, v20.8H, v10.8H // ..................................................*........................ + ldr q15, [x2], #32 // ...............................................................*........... + uzp1 v20.8H, v9.8H, v16.8H // ....................................................*...................... + uzp1 v8.8H, v30.8H, v7.8H // .....................................................*..................... + mul v20.8H, v20.8H, v2.8H // ......................................................*.................... + mul v8.8H, v8.8H, v2.8H // .......................................................*................... + ldr q21, [x4], #32 // .................................................................*......... + smlal v9.4S, v20.4H, v0.4H // ........................................................*.................. + smlal2 v16.4S, v20.8H, v0.8H // .........................................................*................. + smlal v30.4S, v8.4H, v0.4H // ..........................................................*................ + smlal2 v7.4S, v8.8H, v0.8H // ...........................................................*............... + ldr q6, [x4, #-16] // ..................................................................*........ + uzp2 v27.8H, v9.8H, v16.8H // ............................................................*.............. + uzp2 v10.8H, v30.8H, v7.8H // .............................................................*............. 
+ ldr q16, [x2, #-16] // ...................................................*....................... + ldr q30, [x1, #16] // ..............................................................*............ + ld1 {v9.8H}, [x3], #16 // ................................................................*.......... + ldr q1, [x5], #32 // ...................................................................*....... + ldr q12, [x5, #-16] // ....................................................................*...... + ld1 {v24.8H}, [x6], #16 // .....................................................................*..... + ldr q19, [x7], #32 // ......................................................................*.... + ldr q31, [x7, #-16] // .......................................................................*... + ldr q17, [x8], #32 // ........................................................................*.. + ldr q18, [x8, #-16] // .........................................................................*. + ld1 {v25.8H}, [x9], #16 // ..........................................................................* + + // ------------------------------ new position ------------------------------> + // 0 25 50 + // |------------------------|------------------------|------------------------ + // ldr q16, [x2, #16] // *.......................................................................... + // ldr q30, [x1, #16] // ..*........................................................................ + // ldr q15, [x2], #32 // .*......................................................................... + // ld1 {v9.8H}, [x3], #16 // .....*..................................................................... + // ldr q21, [x4], #32 // .......*................................................................... + // ldr q6, [x4, #-16] // ..................*........................................................ 
+ // ldr q1, [x5], #32 // ...................*....................................................... + // ldr q12, [x5, #-16] // ......................*.................................................... + // ld1 {v24.8H}, [x6], #16 // .......................*................................................... + // ldr q19, [x7], #32 // ..................................*........................................ + // ldr q31, [x7, #-16] // ...................................*....................................... + // ldr q17, [x8], #32 // ....................................*...................................... + // ldr q18, [x8, #-16] // .......................................*................................... + // ld1 {v25.8H}, [x9], #16 // ........................................*.................................. + // ldr q20, [x1], #32 // ......*.................................................................... + // uzp1 v7.8H, v15.8H, v16.8H // ...*....................................................................... + // uzp2 v15.8H, v15.8H, v16.8H // ....*...................................................................... + // uzp1 v8.8H, v20.8H, v30.8H // ........*.................................................................. + // uzp2 v20.8H, v20.8H, v30.8H // .........*................................................................. + // smull v30.4S, v8.4H, v15.4H // ..........*................................................................ + // smull2 v15.4S, v8.8H, v15.8H // ...........*............................................................... + // smull v11.4S, v8.4H, v7.4H // ............*.............................................................. + // smull2 v8.4S, v8.8H, v7.8H // .............*............................................................. + // smlal v30.4S, v20.4H, v7.4H // ..............*............................................................ 
+ // smlal2 v15.4S, v20.8H, v7.8H // ...............*........................................................... + // smlal v11.4S, v20.4H, v9.4H // ................*.......................................................... + // smlal2 v8.4S, v20.8H, v9.8H // .................*......................................................... + // uzp1 v7.8H, v21.8H, v6.8H // ....................*...................................................... + // uzp2 v20.8H, v21.8H, v6.8H // .....................*..................................................... + // uzp1 v16.8H, v1.8H, v12.8H // ........................*.................................................. + // uzp2 v9.8H, v1.8H, v12.8H // .........................*................................................. + // smlal v11.4S, v7.4H, v16.4H // ..........................*................................................ + // smlal2 v8.4S, v7.8H, v16.8H // ...........................*............................................... + // smlal v30.4S, v7.4H, v9.4H // ............................*.............................................. + // smlal2 v15.4S, v7.8H, v9.8H // .............................*............................................. + // smlal v11.4S, v20.4H, v24.4H // ..............................*............................................ + // smlal2 v8.4S, v20.8H, v24.8H // ...............................*........................................... + // smlal v30.4S, v20.4H, v16.4H // ................................*.......................................... + // smlal2 v15.4S, v20.8H, v16.8H // .................................*......................................... + // uzp1 v7.8H, v19.8H, v31.8H // .....................................*..................................... + // uzp2 v20.8H, v19.8H, v31.8H // ......................................*.................................... 
+ // uzp1 v16.8H, v17.8H, v18.8H // .........................................*................................. + // uzp2 v9.8H, v17.8H, v18.8H // ..........................................*................................ + // smlal v11.4S, v7.4H, v16.4H // ...........................................*............................... + // smlal2 v8.4S, v7.8H, v16.8H // ............................................*.............................. + // smlal v30.4S, v7.4H, v9.4H // .............................................*............................. + // smlal2 v15.4S, v7.8H, v9.8H // ..............................................*............................ + // smlal v11.4S, v20.4H, v25.4H // ...............................................*........................... + // smlal2 v8.4S, v20.8H, v25.8H // ................................................*.......................... + // smlal v30.4S, v20.4H, v16.4H // .................................................*......................... + // smlal2 v15.4S, v20.8H, v16.8H // ..................................................*........................ + // ldr q16, [x2, #16] // ................................................................*.......... + // uzp1 v7.8H, v11.8H, v8.8H // ....................................................*...................... + // uzp1 v20.8H, v30.8H, v15.8H // .....................................................*..................... + // mul v7.8H, v7.8H, v2.8H // ......................................................*.................... + // mul v20.8H, v20.8H, v2.8H // .......................................................*................... + // smlal v11.4S, v7.4H, v0.4H // .........................................................*................. + // smlal2 v8.4S, v7.8H, v0.8H // ..........................................................*................ + // smlal v30.4S, v20.4H, v0.4H // ...........................................................*............... 
+ // smlal2 v15.4S, v20.8H, v0.8H // ............................................................*.............. + // uzp2 v27.8H, v11.8H, v8.8H // ..............................................................*............ + // uzp2 v10.8H, v30.8H, v15.8H // ...............................................................*........... + // ldr q30, [x1, #16] // .................................................................*......... + // ldr q15, [x2], #32 // ...................................................*....................... + // ld1 {v9.8H}, [x3], #16 // ..................................................................*........ + // ldr q21, [x4], #32 // ........................................................*.................. + // ldr q6, [x4, #-16] // .............................................................*............. + // ldr q1, [x5], #32 // ...................................................................*....... + // ldr q12, [x5, #-16] // ....................................................................*...... + // ld1 {v24.8H}, [x6], #16 // .....................................................................*..... + // ldr q19, [x7], #32 // ......................................................................*.... + // ldr q31, [x7, #-16] // .......................................................................*... + // ldr q17, [x8], #32 // ........................................................................*.. + // ldr q18, [x8, #-16] // .........................................................................*. 
+ // ld1 {v25.8H}, [x9], #16 // ..........................................................................* + + sub count, count, #2 +1: + // Instructions: 65 + // Expected cycles: 80 + // Expected IPC: 0.81 + + // Cycle bound: 80.0 + // IPC bound: 0.81 + + // Wall time: 11.64s + // User time: 11.64s + + // ---------------------- original position -----------------------> + // 0 25 50 + // |------------------------|------------------------|-------------- + ldr q20, [x1], #32 // *................................................................ + uzp1 v7.8H, v15.8H, v16.8H // ......*.......................................................... + uzp2 v15.8H, v15.8H, v16.8H // .......*......................................................... + uzp1 v8.8H, v20.8H, v30.8H // ..*.............................................................. + uzp2 v20.8H, v20.8H, v30.8H // ...*............................................................. + smull v30.4S, v8.4H, v15.4H // .............*................................................... + smull2 v15.4S, v8.8H, v15.8H // ..............*.................................................. + smull v11.4S, v8.4H, v7.4H // .........*....................................................... + smull2 v8.4S, v8.8H, v7.8H // ..........*...................................................... + smlal v30.4S, v20.4H, v7.4H // ...............*................................................. + smlal2 v15.4S, v20.8H, v7.8H // ................*................................................ + smlal v11.4S, v20.4H, v9.4H // ...........*..................................................... + smlal2 v8.4S, v20.8H, v9.8H // ............*.................................................... + uzp1 v7.8H, v21.8H, v6.8H // ...................*............................................. + uzp2 v20.8H, v21.8H, v6.8H // ....................*............................................ 
+ uzp1 v16.8H, v1.8H, v12.8H // .......................*......................................... + uzp2 v9.8H, v1.8H, v12.8H // ........................*........................................ + smlal v11.4S, v7.4H, v16.4H // ..........................*...................................... + smlal2 v8.4S, v7.8H, v16.8H // ...........................*..................................... + smlal v30.4S, v7.4H, v9.4H // ..............................*.................................. + smlal2 v15.4S, v7.8H, v9.8H // ...............................*................................. + smlal v11.4S, v20.4H, v24.4H // ............................*.................................... + smlal2 v8.4S, v20.8H, v24.8H // .............................*................................... + smlal v30.4S, v20.4H, v16.4H // ................................*................................ + smlal2 v15.4S, v20.8H, v16.8H // .................................*............................... + uzp1 v7.8H, v19.8H, v31.8H // ....................................*............................ + uzp2 v20.8H, v19.8H, v31.8H // .....................................*........................... + uzp1 v16.8H, v17.8H, v18.8H // ........................................*........................ + uzp2 v9.8H, v17.8H, v18.8H // .........................................*....................... + smlal v11.4S, v7.4H, v16.4H // ...........................................*..................... + smlal2 v8.4S, v7.8H, v16.8H // ............................................*.................... + smlal v30.4S, v7.4H, v9.4H // ...............................................*................. + smlal2 v15.4S, v7.8H, v9.8H // ................................................*................ + smlal v11.4S, v20.4H, v25.4H // .............................................*................... + smlal2 v8.4S, v20.8H, v25.8H // ..............................................*.................. 
+ smlal v30.4S, v20.4H, v16.4H // .................................................*............... + smlal2 v15.4S, v20.8H, v16.8H // ..................................................*.............. + ldr q16, [x2, #16] // .....e........................................................... + uzp1 v7.8H, v11.8H, v8.8H // ...................................................*............. + uzp1 v20.8H, v30.8H, v15.8H // ........................................................*........ + mul v7.8H, v7.8H, v2.8H // ....................................................*............ + mul v20.8H, v20.8H, v2.8H // .........................................................*....... + zip2 v9.8H, v27.8H, v10.8H // ..............................................................l.. + zip1 v27.8H, v27.8H, v10.8H // .............................................................l... + smlal v11.4S, v7.4H, v0.4H // .....................................................*........... + smlal2 v8.4S, v7.8H, v0.8H // ......................................................*.......... + smlal v30.4S, v20.4H, v0.4H // ..........................................................*...... + smlal2 v15.4S, v20.8H, v0.8H // ...........................................................*..... + str q27, [x0], #32 // ...............................................................l. + uzp2 v27.8H, v11.8H, v8.8H // .......................................................*......... + str q9, [x0, #-16] // ................................................................l + uzp2 v10.8H, v30.8H, v15.8H // ............................................................*.... + ldr q30, [x1, #16] // .e............................................................... + ldr q15, [x2], #32 // ....e............................................................ + ld1 {v9.8H}, [x3], #16 // ........e........................................................ 
+ ldr q21, [x4], #32 // .................e............................................... + ldr q6, [x4, #-16] // ..................e.............................................. + ldr q1, [x5], #32 // .....................e........................................... + ldr q12, [x5, #-16] // ......................e.......................................... + ld1 {v24.8H}, [x6], #16 // .........................e....................................... + ldr q19, [x7], #32 // ..................................e.............................. + ldr q31, [x7, #-16] // ...................................e............................. + ldr q17, [x8], #32 // ......................................e.......................... + ldr q18, [x8, #-16] // .......................................e......................... + ld1 {v25.8H}, [x9], #16 // ..........................................e...................... + + // ---------------------------------------------------------------- new position -----------------------------------------------------------------> + // 0 25 50 75 100 125 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------ + // ldr q12, [x1], #32 // ............................*................................................................~.................................................. + // ldr q13, [x1, #-16] // ...............e............'...................................................~............'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'..*.............................................................'..~............................................... + // uzp2 v4.8h, v12.8h, v13.8h // ............................'...*............................................................'...~.............................................. 
+ // ldr q12, [x2], #32 // ................e...........'....................................................~...........'.................................................. + // ldr q13, [x2, #-16] // e...........................'....................................~...........................'....................................~............. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'*...............................................................'~................................................. + // uzp2 v6.8h, v12.8h, v13.8h // ............................'.*..............................................................'.~................................................ + // ld1 {v7.8h}, [x3], #16 // .................e..........'.....................................................~..........'.................................................. + // smull v8.4s, v3.4h, v5.4h // ............................'......*.........................................................'......~........................................... + // smull2 v10.4s, v3.8h, v5.8h // ............................'.......*........................................................'.......~.......................................... + // smlal v8.4s, v4.4h, v7.4h // ............................'..........*.....................................................'..........~....................................... + // smlal2 v10.4s, v4.8h, v7.8h // ............................'...........*....................................................'...........~...................................... + // smull v9.4s, v3.4h, v6.4h // ............................'....*...........................................................'....~............................................. + // smull2 v11.4s, v3.8h, v6.8h // ............................'.....*..........................................................'.....~............................................ 
+ // smlal v9.4s, v4.4h, v5.4h // ............................'........*.......................................................'........~......................................... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'.........*......................................................'.........~........................................ + // ldr q12, [x4], #32 // ..................e.........'......................................................~.........'.................................................. + // ldr q13, [x4, #-16] // ...................e........'.......................................................~........'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'............*...................................................'............~..................................... + // uzp2 v4.8h, v12.8h, v13.8h // ............................'.............*..................................................'.............~.................................... + // ldr q12, [x5], #32 // ....................e.......'........................................................~.......'.................................................. + // ldr q13, [x5, #-16] // .....................e......'.........................................................~......'.................................................. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'..............*.................................................'..............~................................... + // uzp2 v6.8h, v12.8h, v13.8h // ............................'...............*................................................'...............~.................................. + // ld1 {v7.8h}, [x6], #16 // ......................e.....'..........................................................~.....'.................................................. 
+ // smlal v8.4s, v3.4h, v5.4h // ............................'................*...............................................'................~................................. + // smlal2 v10.4s, v3.8h, v5.8h // ............................'.................*..............................................'.................~................................ + // smlal v8.4s, v4.4h, v7.4h // ............................'....................*...........................................'....................~............................. + // smlal2 v10.4s, v4.8h, v7.8h // ............................'.....................*..........................................'.....................~............................ + // smlal v9.4s, v3.4h, v6.4h // ............................'..................*.............................................'..................~............................... + // smlal2 v11.4s, v3.8h, v6.8h // ............................'...................*............................................'...................~.............................. + // smlal v9.4s, v4.4h, v5.4h // ............................'......................*.........................................'......................~........................... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'.......................*........................................'.......................~.......................... + // ldr q12, [x7], #32 // .......................e....'...........................................................~....'.................................................. + // ldr q13, [x7, #-16] // ........................e...'............................................................~...'.................................................. + // uzp1 v3.8h, v12.8h, v13.8h // ............................'........................*.......................................'........................~......................... 
+ // uzp2 v4.8h, v12.8h, v13.8h // ............................'.........................*......................................'.........................~........................ + // ldr q12, [x8], #32 // .........................e..'.............................................................~..'.................................................. + // ldr q13, [x8, #-16] // ..........................e.'..............................................................~.'.................................................. + // uzp1 v5.8h, v12.8h, v13.8h // ............................'..........................*.....................................'..........................~....................... + // uzp2 v6.8h, v12.8h, v13.8h // ............................'...........................*....................................'...........................~...................... + // ld1 {v7.8h}, [x9], #16 // ...........................e'...............................................................~'.................................................. + // smlal v8.4s, v3.4h, v5.4h // ............................'............................*...................................'............................~..................... + // smlal2 v10.4s, v3.8h, v5.8h // ............................'.............................*..................................'.............................~.................... + // smlal v8.4s, v4.4h, v7.4h // ............................'................................*...............................'................................~................. + // smlal2 v10.4s, v4.8h, v7.8h // ............................'.................................*..............................'.................................~................ + // smlal v9.4s, v3.4h, v6.4h // ............................'..............................*.................................'..............................~................... 
+ // smlal2 v11.4s, v3.8h, v6.8h // ............................'...............................*................................'...............................~.................. + // smlal v9.4s, v4.4h, v5.4h // ............................'..................................*.............................'..................................~............... + // smlal2 v11.4s, v4.8h, v5.8h // ............................'...................................*............................'...................................~.............. + // uzp1 v28.8h, v8.8h, v10.8h // .~..........................'.....................................*..........................'.....................................~............ + // mul v28.8h, v28.8h, v2.8h // ...~........................'.......................................*........................'.......................................~.......... + // smlal v8.4s, v28.4h, v0.4h // .......~....................'...........................................*....................'...........................................~...... + // smlal2 v10.4s, v28.8h, v0.8h // ........~...................'............................................*...................'............................................~..... + // uzp2 v26.8h, v8.8h, v10.8h // ............~...............'................................................*...............'................................................~. + // uzp1 v28.8h, v9.8h, v11.8h // ..~.........................'......................................*.........................'......................................~........... + // mul v28.8h, v28.8h, v2.8h // ....~.......................'........................................*.......................'........................................~......... + // smlal v9.4s, v28.4h, v0.4h // .........~..................'.............................................*..................'.............................................~.... 
+ // smlal2 v11.4s, v28.8h, v0.8h // ..........~.................'..............................................*.................'..............................................~... + // uzp2 v27.8h, v9.8h, v11.8h // ..............~.............'..................................................*.............'.................................................. + // zip1 v12.8h, v26.8h, v27.8h // ......~.....................'..........................................~.....................'..........................................l....... + // zip2 v13.8h, v26.8h, v27.8h // .....~......................'.........................................~......................'.........................................l........ + // str q12, [x0], #32 // ...........~................'...............................................~................'...............................................l.. + // str q13, [x0, #-16] // .............~..............'.................................................~..............'.................................................l + + sub count, count, #1 + cbnz count, 1b + // Instructions: 55 + // Expected cycles: 61 + // Expected IPC: 0.90 + + // Cycle bound: 61.0 + // IPC bound: 0.90 + + // Wall time: 8.41s + // User time: 8.41s + + // ----------------- original position ------------------> + // 0 25 50 + // |------------------------|------------------------|---- + ldr q7, [x1], #32 // *...................................................... + uzp1 v20.8H, v15.8H, v16.8H // .*..................................................... + uzp2 v15.8H, v15.8H, v16.8H // ..*.................................................... + uzp1 v23.8H, v7.8H, v30.8H // ...*................................................... + uzp2 v11.8H, v7.8H, v30.8H // ....*.................................................. + smull2 v8.4S, v23.8H, v20.8H // ........*.............................................. 
+ smull v5.4S, v23.4H, v20.4H // .......*............................................... + smull2 v30.4S, v23.8H, v15.8H // ......*................................................ + uzp1 v28.8H, v1.8H, v12.8H // ...............*....................................... + smlal2 v8.4S, v11.8H, v9.8H // ............*.......................................... + smlal v5.4S, v11.4H, v9.4H // ...........*........................................... + uzp1 v3.8H, v21.8H, v6.8H // .............*......................................... + smull v16.4S, v23.4H, v15.4H // .....*................................................. + smlal2 v8.4S, v3.8H, v28.8H // ..................*.................................... + smlal v5.4S, v3.4H, v28.4H // .................*..................................... + uzp2 v29.8H, v21.8H, v6.8H // ..............*........................................ + uzp1 v7.8H, v17.8H, v18.8H // ...........................*........................... + smlal2 v8.4S, v29.8H, v24.8H // ......................*................................ + uzp1 v14.8H, v19.8H, v31.8H // .........................*............................. + smlal v16.4S, v11.4H, v20.4H // .........*............................................. + smlal2 v30.4S, v11.8H, v20.8H // ..........*............................................ + smlal2 v8.4S, v14.8H, v7.8H // ..............................*........................ + uzp2 v20.8H, v1.8H, v12.8H // ................*...................................... + uzp2 v21.8H, v19.8H, v31.8H // ..........................*............................ + smlal2 v30.4S, v3.8H, v20.8H // ....................*.................................. + smlal v16.4S, v3.4H, v20.4H // ...................*................................... + smlal v5.4S, v29.4H, v24.4H // .....................*................................. + uzp2 v9.8H, v17.8H, v18.8H // ............................*.......................... 
+ smlal2 v30.4S, v29.8H, v28.8H // ........................*.............................. + smlal v16.4S, v29.4H, v28.4H // .......................*............................... + smlal v5.4S, v14.4H, v7.4H // .............................*......................... + smlal2 v8.4S, v21.8H, v25.8H // ..................................*.................... + smlal2 v30.4S, v14.8H, v9.8H // ................................*...................... + smlal v16.4S, v14.4H, v9.4H // ...............................*....................... + smlal v5.4S, v21.4H, v25.4H // .................................*..................... + zip1 v20.8H, v27.8H, v10.8H // ..........................................*............ + smlal2 v30.4S, v21.8H, v7.8H // ....................................*.................. + smlal v16.4S, v21.4H, v7.4H // ...................................*................... + uzp1 v7.8H, v5.8H, v8.8H // .....................................*................. + str q20, [x0], #32 // ...............................................*....... + mul v15.8H, v7.8H, v2.8H // .......................................*............... + uzp1 v7.8H, v16.8H, v30.8H // ......................................*................ + zip2 v31.8H, v27.8H, v10.8H // .........................................*............. + mul v20.8H, v7.8H, v2.8H // ........................................*.............. + smlal v5.4S, v15.4H, v0.4H // ...........................................*........... + smlal2 v8.4S, v15.8H, v0.8H // ............................................*.......... + str q31, [x0, #-16] // .................................................*..... + smlal2 v30.4S, v20.8H, v0.8H // ..............................................*........ + smlal v16.4S, v20.4H, v0.4H // .............................................*......... + uzp2 v15.8H, v5.8H, v8.8H // ................................................*...... 
+ uzp2 v20.8H, v16.8H, v30.8H // ..................................................*.... + zip1 v7.8H, v15.8H, v20.8H // ....................................................*.. + zip2 v20.8H, v15.8H, v20.8H // ...................................................*... + str q7, [x0], #32 // .....................................................*. + str q20, [x0, #-16] // ......................................................* + + // -------------------- new position --------------------> + // 0 25 50 + // |------------------------|------------------------|---- + // ldr q20, [x1], #32 // *...................................................... + // uzp1 v7.8H, v15.8H, v16.8H // .*..................................................... + // uzp2 v15.8H, v15.8H, v16.8H // ..*.................................................... + // uzp1 v8.8H, v20.8H, v30.8H // ...*................................................... + // uzp2 v20.8H, v20.8H, v30.8H // ....*.................................................. + // smull v30.4S, v8.4H, v15.4H // ............*.......................................... + // smull2 v15.4S, v8.8H, v15.8H // .......*............................................... + // smull v11.4S, v8.4H, v7.4H // ......*................................................ + // smull2 v8.4S, v8.8H, v7.8H // .....*................................................. + // smlal v30.4S, v20.4H, v7.4H // ...................*................................... + // smlal2 v15.4S, v20.8H, v7.8H // ....................*.................................. + // smlal v11.4S, v20.4H, v9.4H // ..........*............................................ + // smlal2 v8.4S, v20.8H, v9.8H // .........*............................................. + // uzp1 v7.8H, v21.8H, v6.8H // ...........*........................................... + // uzp2 v20.8H, v21.8H, v6.8H // ...............*....................................... 
+ // uzp1 v16.8H, v1.8H, v12.8H // ........*.............................................. + // uzp2 v9.8H, v1.8H, v12.8H // ......................*................................ + // smlal v11.4S, v7.4H, v16.4H // ..............*........................................ + // smlal2 v8.4S, v7.8H, v16.8H // .............*......................................... + // smlal v30.4S, v7.4H, v9.4H // .........................*............................. + // smlal2 v15.4S, v7.8H, v9.8H // ........................*.............................. + // smlal v11.4S, v20.4H, v24.4H // ..........................*............................ + // smlal2 v8.4S, v20.8H, v24.8H // .................*..................................... + // smlal v30.4S, v20.4H, v16.4H // .............................*......................... + // smlal2 v15.4S, v20.8H, v16.8H // ............................*.......................... + // uzp1 v7.8H, v19.8H, v31.8H // ..................*.................................... + // uzp2 v20.8H, v19.8H, v31.8H // .......................*............................... + // uzp1 v16.8H, v17.8H, v18.8H // ................*...................................... + // uzp2 v9.8H, v17.8H, v18.8H // ...........................*........................... + // smlal v11.4S, v7.4H, v16.4H // ..............................*........................ + // smlal2 v8.4S, v7.8H, v16.8H // .....................*................................. + // smlal v30.4S, v7.4H, v9.4H // .................................*..................... + // smlal2 v15.4S, v7.8H, v9.8H // ................................*...................... + // smlal v11.4S, v20.4H, v25.4H // ..................................*.................... + // smlal2 v8.4S, v20.8H, v25.8H // ...............................*....................... + // smlal v30.4S, v20.4H, v16.4H // .....................................*................. 
+ // smlal2 v15.4S, v20.8H, v16.8H // ....................................*.................. + // uzp1 v7.8H, v11.8H, v8.8H // ......................................*................ + // uzp1 v20.8H, v30.8H, v15.8H // .........................................*............. + // mul v7.8H, v7.8H, v2.8H // ........................................*.............. + // mul v20.8H, v20.8H, v2.8H // ...........................................*........... + // zip2 v9.8H, v27.8H, v10.8H // ..........................................*............ + // zip1 v27.8H, v27.8H, v10.8H // ...................................*................... + // smlal v11.4S, v7.4H, v0.4H // ............................................*.......... + // smlal2 v8.4S, v7.8H, v0.8H // .............................................*......... + // smlal v30.4S, v20.4H, v0.4H // ................................................*...... + // smlal2 v15.4S, v20.8H, v0.8H // ...............................................*....... + // str q27, [x0], #32 // .......................................*............... + // uzp2 v27.8H, v11.8H, v8.8H // .................................................*..... + // str q9, [x0, #-16] // ..............................................*........ + // uzp2 v10.8H, v30.8H, v15.8H // ..................................................*.... + // zip2 v9.8H, v27.8H, v10.8H // ....................................................*.. + // zip1 v27.8H, v27.8H, v10.8H // ...................................................*... + // str q27, [x0], #32 // .....................................................*. 
+ // str q9, [x0, #-16] // ......................................................* + + + pop_stack + ret +#endif /* MLKEM_K == 3 */ + +#if MLKEM_K == 4 +.global MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt) + +MLKEM_ASM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_asm_opt): + push_stack + ldr q_modulus, c_modulus + ldr q_modulus_twisted, c_modulus_twisted + + // Computed bases of vector entries + + add a1_ptr, a0_ptr, #(1 * 512) + add b1_ptr, b0_ptr, #(1 * 512) + add b1_cache_ptr, b0_cache_ptr, #(1 * 512/2) + add a2_ptr, a0_ptr, #(2 * 512) + add b2_ptr, b0_ptr, #(2 * 512) + add b2_cache_ptr, b0_cache_ptr, #(2 * 512/2) + add a3_ptr, a0_ptr, #(3 * 512) + add b3_ptr, b0_ptr, #(3 * 512) + add b3_cache_ptr, b0_cache_ptr, #(3 * 512/2) + + // Bounds: + + // Each pmull is bound by 2*4096*2^15=2^28, so the final value + // before Montgomery reduction is bound by 2^30. + + mov count, #(MLKEM_N / 16) + // Instructions: 114 + // Expected cycles: 153 + // Expected IPC: 0.75 + // + // Cycle bound: 153.0 + // IPC bound: 0.75 + // + // Wall time: 0.69s + // User time: 0.69s + // + // ----------------------------------------------- original position -----------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + ldr q23, [x2, #16] // .*................................................................................................................ + ldr q19, [x2], #32 // *................................................................................................................. + ldr q17, [x5], #32 // ..*............................................................................................................... + uzp2 v13.8H, v19.8H, v23.8H // ..........*....................................................................................................... 
+ uzp1 v19.8H, v19.8H, v23.8H // ...........*...................................................................................................... + ldr q23, [x5, #-16] // ...*.............................................................................................................. + ldr q30, [x1, #16] // .....*............................................................................................................ + uzp2 v9.8H, v17.8H, v23.8H // ....*............................................................................................................. + uzp1 v23.8H, v17.8H, v23.8H // .......*.......................................................................................................... + ldr q17, [x1], #32 // ......*........................................................................................................... + ldr q10, [x7, #16] // .............*.................................................................................................... + uzp1 v12.8H, v17.8H, v30.8H // ........*......................................................................................................... + uzp2 v17.8H, v17.8H, v30.8H // .........*........................................................................................................ + smull2 v30.4S, v12.8H, v13.8H // ............*..................................................................................................... + smull v13.4S, v12.4H, v13.4H // ............................................*..................................................................... + smull2 v22.4S, v12.8H, v19.8H // .....................................*............................................................................ + smull v12.4S, v12.4H, v19.4H // ..........................................*....................................................................... 
+ smlal2 v30.4S, v17.8H, v19.8H // ...............................*.................................................................................. + smlal v13.4S, v17.4H, v19.4H // ...............................................*.................................................................. + ldr q19, [x4], #32 // ....................*............................................................................................. + ldr q16, [x4, #-16] // .....................*............................................................................................ + ld1 {v8.8H}, [x3], #16 // ................................*................................................................................. + uzp1 v26.8H, v19.8H, v16.8H // .......................*.......................................................................................... + uzp2 v19.8H, v19.8H, v16.8H // ........................*......................................................................................... + smlal2 v30.4S, v26.8H, v9.8H // .................................*................................................................................ + smlal v13.4S, v26.4H, v9.4H // ..................................................*............................................................... + smlal2 v22.4S, v17.8H, v8.8H // ........................................*......................................................................... + smlal v12.4S, v17.4H, v8.4H // .................................................*................................................................ + smlal2 v30.4S, v19.8H, v23.8H // ...................................*.............................................................................. + smlal v13.4S, v19.4H, v23.4H // .......................................................*.......................................................... 
+ smlal2 v22.4S, v26.8H, v23.8H // ...........................................*...................................................................... + smlal v12.4S, v26.4H, v23.4H // .....................................................*............................................................ + ldr q23, [x7], #32 // ......................*........................................................................................... + ldr q17, [x8, #16] // ..............*................................................................................................... + uzp1 v9.8H, v23.8H, v10.8H // ..........................*....................................................................................... + uzp2 v23.8H, v23.8H, v10.8H // ....................................*............................................................................. + ldr q10, [x10], #32 // ...............*.................................................................................................. + ldr q16, [x10, #-16] // ................*................................................................................................. + ld1 {v8.8H}, [x12], #16 // .................*................................................................................................ + uzp1 v26.8H, v10.8H, v16.8H // ..................*............................................................................................... + uzp2 v10.8H, v10.8H, v16.8H // ...................*.............................................................................................. + ld1 {v16.8H}, [x6], #16 // .........................*........................................................................................ + ldr q3, [x11, #16] // ...........................*...................................................................................... 
+ smlal2 v22.4S, v19.8H, v16.8H // ..............................................*................................................................... + smlal v12.4S, v19.4H, v16.4H // ........................................................*......................................................... + ldr q19, [x11], #32 // ............................*..................................................................................... + ld1 {v16.8H}, [x9], #16 // .............................*.................................................................................... + uzp1 v4.8H, v19.8H, v3.8H // ..................................*............................................................................... + uzp2 v19.8H, v19.8H, v3.8H // .......................................*.......................................................................... + ldr q3, [x8], #32 // ..............................*................................................................................... + ldr q31, [x2], #32 // ......................................*........................................................................... + uzp1 v6.8H, v3.8H, v17.8H // ...................................................*.............................................................. + uzp2 v17.8H, v3.8H, v17.8H // .........................................................*........................................................ + smlal2 v22.4S, v9.8H, v6.8H // ..........................................................*....................................................... + smlal2 v30.4S, v9.8H, v17.8H // ...........................................................*...................................................... + smlal v13.4S, v9.4H, v17.4H // ............................................................*..................................................... 
+ smlal v12.4S, v9.4H, v6.4H // .............................................................*.................................................... + smlal2 v22.4S, v23.8H, v16.8H // ..............................................................*................................................... + smlal2 v30.4S, v23.8H, v6.8H // ...............................................................*.................................................. + smlal v13.4S, v23.4H, v6.4H // ................................................................*................................................. + smlal v12.4S, v23.4H, v16.4H // .................................................................*................................................ + smlal2 v22.4S, v26.8H, v4.8H // ..................................................................*............................................... + smlal2 v30.4S, v26.8H, v19.8H // ...................................................................*.............................................. + smlal v13.4S, v26.4H, v19.4H // ....................................................................*............................................. + smlal v12.4S, v26.4H, v4.4H // .....................................................................*............................................ + smlal2 v22.4S, v10.8H, v8.8H // ......................................................................*........................................... + smlal2 v30.4S, v10.8H, v4.8H // .......................................................................*.......................................... + smlal v13.4S, v10.4H, v4.4H // ........................................................................*......................................... + smlal v12.4S, v10.4H, v8.4H // .........................................................................*........................................ 
+ ldr q19, [x2, #-16] // .........................................*........................................................................ + uzp1 v23.8H, v13.8H, v30.8H // ...........................................................................*...................................... + uzp1 v17.8H, v12.8H, v22.8H // ....................................................................................*............................. + mul v23.8H, v23.8H, v2.8H // .............................................................................*.................................... + uzp2 v21.8H, v31.8H, v19.8H // ................................................................................*................................. + uzp1 v19.8H, v31.8H, v19.8H // ...................................................................................*.............................. + mul v17.8H, v17.8H, v2.8H // .....................................................................................*............................ + smlal v13.4S, v23.4H, v0.4H // .................................................................................*................................ + smlal2 v30.4S, v23.8H, v0.8H // ..................................................................................*............................... + ldr q23, [x5], #32 // .............................................*.................................................................... + smlal2 v22.4S, v17.8H, v0.8H // ...........................................................................................................*...... + uzp2 v15.8H, v13.8H, v30.8H // ......................................................................................*........................... + smlal v12.4S, v17.4H, v0.4H // ............................................................................................................*..... 
+ ldr q17, [x5, #-16] // ................................................*................................................................. + ldr q13, [x1, #16] // ......................................................*........................................................... + uzp2 v27.8H, v23.8H, v17.8H // ....................................................*............................................................. + uzp1 v28.8H, v23.8H, v17.8H // ............................................................................*..................................... + uzp2 v7.8H, v12.8H, v22.8H // ...............................................................................................................*.. + ldr q23, [x1], #32 // ..........................................................................*....................................... + zip1 v5.8H, v7.8H, v15.8H // .................................................................................................................* + ldr q3, [x7, #16] // ........................................................................................*......................... + uzp1 v31.8H, v23.8H, v13.8H // ..............................................................................*................................... + uzp2 v16.8H, v23.8H, v13.8H // ...............................................................................*.................................. + smull2 v24.4S, v31.8H, v21.8H // .......................................................................................*.......................... + ldr q6, [x8, #16] // .........................................................................................*........................ + ldr q23, [x10], #32 // ..........................................................................................*....................... 
+ smlal2 v24.4S, v16.8H, v19.8H // ..........................................................................................................*....... + ldr q17, [x10, #-16] // ...........................................................................................*...................... + ld1 {v22.8H}, [x12], #16 // ............................................................................................*..................... + uzp1 v30.8H, v23.8H, v17.8H // .............................................................................................*.................... + uzp2 v11.8H, v23.8H, v17.8H // ..............................................................................................*................... + ldr q23, [x4], #32 // ...............................................................................................*.................. + ldr q17, [x4, #-16] // ................................................................................................*................. + ldr q4, [x7], #32 // .................................................................................................*................ + uzp1 v20.8H, v23.8H, v17.8H // ..................................................................................................*............... + uzp2 v26.8H, v23.8H, v17.8H // ...................................................................................................*.............. + uzp1 v9.8H, v4.8H, v3.8H // .....................................................................................................*............ + smlal2 v24.4S, v20.8H, v27.8H // ..............................................................................................................*... + ld1 {v8.8H}, [x6], #16 // ....................................................................................................*............. 
+ ldr q25, [x11, #16] // ......................................................................................................*........... + ldr q29, [x11], #32 // .......................................................................................................*.......... + ld1 {v12.8H}, [x9], #16 // ........................................................................................................*......... + uzp1 v10.8H, v29.8H, v25.8H // ................................................................................................................*. + ldr q14, [x8], #32 // .........................................................................................................*........ + ld1 {v23.8H}, [x3], #16 // .............................................................................................................*.... + + // ------------------------------------------------- new position --------------------------------------------------> + // 0 25 50 75 100 + // |------------------------|------------------------|------------------------|------------------------|------------- + // ldr q3, [x2], #32 // .*................................................................................................................ + // ldr q17, [x2, #-16] // *................................................................................................................. + // ldr q21, [x5], #32 // ..*............................................................................................................... + // ldr q19, [x5, #-16] // .....*............................................................................................................ + // uzp2 v27.8H, v21.8H, v19.8H // .......*.......................................................................................................... + // ldr q25, [x1, #16] // ......*........................................................................................................... 
+ // ldr q22, [x1], #32 // .........*........................................................................................................ + // uzp1 v28.8H, v21.8H, v19.8H // ........*......................................................................................................... + // uzp1 v31.8H, v22.8H, v25.8H // ...........*...................................................................................................... + // uzp2 v16.8H, v22.8H, v25.8H // ............*..................................................................................................... + // uzp2 v21.8H, v3.8H, v17.8H // ...*.............................................................................................................. + // uzp1 v19.8H, v3.8H, v17.8H // ....*............................................................................................................. + // smull2 v24.4S, v31.8H, v21.8H // .............*.................................................................................................... + // ldr q3, [x7, #16] // ..........*....................................................................................................... + // ldr q6, [x8, #16] // .................................*................................................................................ + // ldr q8, [x10], #32 // ....................................*............................................................................. + // ldr q26, [x10, #-16] // .....................................*............................................................................ + // ld1 {v22.8H}, [x12], #16 // ......................................*........................................................................... + // uzp1 v30.8H, v8.8H, v26.8H // .......................................*.......................................................................... 
+ // uzp2 v11.8H, v8.8H, v26.8H // ........................................*......................................................................... + // ldr q8, [x4], #32 // ...................*.............................................................................................. + // ldr q26, [x4, #-16] // ....................*............................................................................................. + // ldr q4, [x7], #32 // ................................*................................................................................. + // uzp1 v20.8H, v8.8H, v26.8H // ......................*........................................................................................... + // uzp2 v26.8H, v8.8H, v26.8H // .......................*.......................................................................................... + // ld1 {v8.8H}, [x6], #16 // .........................................*........................................................................ + // uzp1 v9.8H, v4.8H, v3.8H // ..................................*............................................................................... + // ldr q25, [x11, #16] // ..........................................*....................................................................... + // ldr q29, [x11], #32 // .............................................*.................................................................... + // ld1 {v12.8H}, [x9], #16 // ..............................................*................................................................... + // ldr q14, [x8], #32 // .................................................*................................................................ + // smlal2 v24.4S, v16.8H, v19.8H // .................*................................................................................................ 
+ // ld1 {v23.8H}, [x3], #16 // .....................*............................................................................................ + // smlal2 v24.4S, v20.8H, v27.8H // ........................*......................................................................................... + // uzp1 v10.8H, v29.8H, v25.8H // ...............................................*.................................................................. + // smlal2 v24.4S, v26.8H, v28.8H // ............................*..................................................................................... + // uzp2 v4.8H, v4.8H, v3.8H // ...................................*.............................................................................. + // smull2 v13.4S, v31.8H, v19.8H // ...............*.................................................................................................. + // ldr q3, [x2], #32 // ..................................................*............................................................... + // uzp2 v1.8H, v29.8H, v25.8H // ................................................*................................................................. + // smlal2 v13.4S, v16.8H, v23.8H // ..........................*....................................................................................... + // ldr q17, [x2, #-16] // .....................................................................*............................................ + // smull v18.4S, v31.4H, v19.4H // ................*................................................................................................. + // smlal2 v13.4S, v20.8H, v28.8H // ..............................*................................................................................... + // smull v29.4S, v31.4H, v21.4H // ..............*................................................................................................... 
+ // ldr q21, [x5], #32 // ..............................................................................*................................... + // smlal2 v13.4S, v26.8H, v8.8H // ...........................................*...................................................................... + // smlal v29.4S, v16.4H, v19.4H // ..................*............................................................................................... + // ldr q19, [x5, #-16] // ..................................................................................*............................... + // smlal v18.4S, v16.4H, v23.4H // ...........................*...................................................................................... + // smlal v29.4S, v20.4H, v27.4H // .........................*........................................................................................ + // uzp1 v31.8H, v14.8H, v6.8H // ...................................................*.............................................................. + // uzp2 v27.8H, v21.8H, v19.8H // ....................................................................................*............................. + // smlal v18.4S, v20.4H, v28.4H // ...............................*.................................................................................. + // ldr q25, [x1, #16] // ...................................................................................*.............................. + // smlal v29.4S, v26.4H, v28.4H // .............................*.................................................................................... + // smlal v18.4S, v26.4H, v8.4H // ............................................*..................................................................... + // uzp2 v26.8H, v14.8H, v6.8H // ....................................................*............................................................. 
+ // smlal2 v13.4S, v9.8H, v31.8H // .....................................................*............................................................ + // smlal2 v24.4S, v9.8H, v26.8H // ......................................................*........................................................... + // smlal v29.4S, v9.4H, v26.4H // .......................................................*.......................................................... + // smlal v18.4S, v9.4H, v31.4H // ........................................................*......................................................... + // smlal2 v13.4S, v4.8H, v12.8H // .........................................................*........................................................ + // smlal2 v24.4S, v4.8H, v31.8H // ..........................................................*....................................................... + // smlal v29.4S, v4.4H, v31.4H // ...........................................................*...................................................... + // smlal v18.4S, v4.4H, v12.4H // ............................................................*..................................................... + // smlal2 v13.4S, v30.8H, v10.8H // .............................................................*.................................................... + // smlal2 v24.4S, v30.8H, v1.8H // ..............................................................*................................................... + // smlal v29.4S, v30.4H, v1.4H // ...............................................................*.................................................. + // smlal v18.4S, v30.4H, v10.4H // ................................................................*................................................. + // smlal2 v13.4S, v11.8H, v22.8H // .................................................................*................................................ 
+ // smlal2 v24.4S, v11.8H, v10.8H // ..................................................................*............................................... + // smlal v29.4S, v11.4H, v10.4H // ...................................................................*.............................................. + // smlal v18.4S, v11.4H, v22.4H // ....................................................................*............................................. + // ldr q22, [x1], #32 // .......................................................................................*.......................... + // uzp1 v31.8H, v29.8H, v24.8H // ......................................................................*........................................... + // uzp1 v28.8H, v21.8H, v19.8H // .....................................................................................*............................ + // mul v19.8H, v31.8H, v2.8H // ........................................................................*......................................... + // uzp1 v31.8H, v22.8H, v25.8H // ..........................................................................................*....................... + // uzp2 v16.8H, v22.8H, v25.8H // ...........................................................................................*...................... + // uzp2 v21.8H, v3.8H, v17.8H // .........................................................................*........................................ + // smlal v29.4S, v19.4H, v0.4H // ............................................................................*..................................... + // smlal2 v24.4S, v19.8H, v0.8H // .............................................................................*.................................... + // uzp1 v19.8H, v3.8H, v17.8H // ..........................................................................*....................................... 
+ // uzp1 v26.8H, v18.8H, v13.8H // .......................................................................*.......................................... + // mul v23.8H, v26.8H, v2.8H // ...........................................................................*...................................... + // uzp2 v15.8H, v29.8H, v24.8H // ................................................................................*................................. + // smull2 v24.4S, v31.8H, v21.8H // ............................................................................................*..................... + // ldr q3, [x7, #16] // .........................................................................................*........................ + // ldr q6, [x8, #16] // .............................................................................................*.................... + // ldr q8, [x10], #32 // ..............................................................................................*................... + // ldr q26, [x10, #-16] // ................................................................................................*................. + // ld1 {v22.8H}, [x12], #16 // .................................................................................................*................ + // uzp1 v30.8H, v8.8H, v26.8H // ..................................................................................................*............... + // uzp2 v11.8H, v8.8H, v26.8H // ...................................................................................................*.............. + // ldr q8, [x4], #32 // ....................................................................................................*............. + // ldr q26, [x4, #-16] // .....................................................................................................*............ 
+ // ldr q4, [x7], #32 // ......................................................................................................*........... + // uzp1 v20.8H, v8.8H, v26.8H // .......................................................................................................*.......... + // uzp2 v26.8H, v8.8H, v26.8H // ........................................................................................................*......... + // ld1 {v8.8H}, [x6], #16 // ...........................................................................................................*...... + // uzp1 v9.8H, v4.8H, v3.8H // .........................................................................................................*........ + // ldr q25, [x11, #16] // ............................................................................................................*..... + // ldr q29, [x11], #32 // .............................................................................................................*.... + // ld1 {v12.8H}, [x9], #16 // ..............................................................................................................*... + // ldr q14, [x8], #32 // ................................................................................................................*. + // smlal2 v24.4S, v16.8H, v19.8H // ...............................................................................................*.................. + // smlal2 v13.4S, v23.8H, v0.8H // ...............................................................................*.................................. + // smlal v18.4S, v23.4H, v0.4H // .................................................................................*................................ 
+ // ld1 {v23.8H}, [x3], #16 // .................................................................................................................* + // smlal2 v24.4S, v20.8H, v27.8H // ..........................................................................................................*....... + // uzp2 v7.8H, v18.8H, v13.8H // ......................................................................................*........................... + // uzp1 v10.8H, v29.8H, v25.8H // ...............................................................................................................*.. + // zip1 v5.8H, v7.8H, v15.8H // ........................................................................................*......................... + + sub count, count, #2 +1: + // Instructions: 82 + // Expected cycles: 102 + // Expected IPC: 0.80 + // + // Cycle bound: 102.0 + // IPC bound: 0.80 + // + // Wall time: 15.93s + // User time: 15.93s + // + // ------------------------------- original position -------------------------------> + // 0 25 50 75 + // |------------------------|------------------------|------------------------|------ + smlal2 v24.4S, v26.8H, v28.8H // .................................*................................................ + uzp2 v4.8H, v4.8H, v3.8H // .....................................*............................................ + smull2 v13.4S, v31.8H, v19.8H // ..........*....................................................................... + ldr q3, [x2], #32 // ....e............................................................................. + uzp2 v1.8H, v29.8H, v25.8H // ..........................................................*....................... + smlal2 v13.4S, v16.8H, v23.8H // ............*..................................................................... + ldr q17, [x2, #-16] // .....e............................................................................ 
+ smull v18.4S, v31.4H, v19.4H // .........*........................................................................ + smlal2 v13.4S, v20.8H, v28.8H // ...........................*...................................................... + smull v29.4S, v31.4H, v21.4H // .............*.................................................................... + ldr q21, [x5], #32 // .....................e............................................................ + smlal2 v13.4S, v26.8H, v8.8H // .............................*.................................................... + smlal v29.4S, v16.4H, v19.4H // ...............*.................................................................. + ldr q19, [x5, #-16] // ......................e........................................................... + smlal v18.4S, v16.4H, v23.4H // ...........*...................................................................... + smlal v29.4S, v20.4H, v27.4H // ..............................*................................................... + uzp1 v31.8H, v14.8H, v6.8H // ........................................*......................................... + uzp2 v27.8H, v21.8H, v19.8H // ........................e......................................................... + smlal v18.4S, v20.4H, v28.4H // ..........................*....................................................... + ldr q25, [x1, #16] // .e................................................................................ + smlal v29.4S, v26.4H, v28.4H // ................................*................................................. + smlal v18.4S, v26.4H, v8.4H // ............................*..................................................... + uzp2 v26.8H, v14.8H, v6.8H // .........................................*........................................ + smlal2 v13.4S, v9.8H, v31.8H // ............................................*..................................... 
+ smlal2 v24.4S, v9.8H, v26.8H // ................................................*................................. + smlal v29.4S, v9.4H, v26.4H // ...............................................*.................................. + smlal v18.4S, v9.4H, v31.4H // ...........................................*...................................... + smlal2 v13.4S, v4.8H, v12.8H // ..............................................*................................... + smlal2 v24.4S, v4.8H, v31.8H // ..................................................*............................... + smlal v29.4S, v4.4H, v31.4H // .................................................*................................ + smlal v18.4S, v4.4H, v12.4H // .............................................*.................................... + smlal2 v13.4S, v30.8H, v10.8H // .............................................................*.................... + smlal2 v24.4S, v30.8H, v1.8H // .................................................................*................ + smlal v29.4S, v30.4H, v1.4H // ................................................................*................. + smlal v18.4S, v30.4H, v10.4H // ............................................................*..................... + smlal2 v13.4S, v11.8H, v22.8H // ...............................................................*.................. + smlal2 v24.4S, v11.8H, v10.8H // ...................................................................*.............. + smlal v29.4S, v11.4H, v10.4H // ..................................................................*............... + smlal v18.4S, v11.4H, v22.4H // ..............................................................*................... + ldr q22, [x1], #32 // e................................................................................. + uzp1 v31.8H, v29.8H, v24.8H // .........................................................................*........ 
+ uzp1 v28.8H, v21.8H, v19.8H // .......................e.......................................................... + mul v19.8H, v31.8H, v2.8H // ..........................................................................*....... + uzp1 v31.8H, v22.8H, v25.8H // ..e............................................................................... + uzp2 v16.8H, v22.8H, v25.8H // ...e.............................................................................. + uzp2 v21.8H, v3.8H, v17.8H // .......e.......................................................................... + smlal v29.4S, v19.4H, v0.4H // ...........................................................................*...... + smlal2 v24.4S, v19.8H, v0.8H // ............................................................................*..... + uzp1 v19.8H, v3.8H, v17.8H // ......e........................................................................... + uzp1 v26.8H, v18.8H, v13.8H // ....................................................................*............. + zip2 v14.8H, v7.8H, v15.8H // ...............................................................................l.. + mul v23.8H, v26.8H, v2.8H // .....................................................................*............ + uzp2 v15.8H, v29.8H, v24.8H // .............................................................................*.... + smull2 v24.4S, v31.8H, v21.8H // ..............e................................................................... + str q14, [x0, #16] // .................................................................................l + ldr q3, [x7, #16] // ...................................e.............................................. + ldr q6, [x8, #16] // .......................................e.......................................... + ldr q8, [x10], #32 // ...................................................e.............................. 
+ ldr q26, [x10, #-16] // ....................................................e............................. + ld1 {v22.8H}, [x12], #16 // ...........................................................e...................... + uzp1 v30.8H, v8.8H, v26.8H // .....................................................e............................ + uzp2 v11.8H, v8.8H, v26.8H // ......................................................e........................... + ldr q8, [x4], #32 // .................e................................................................ + ldr q26, [x4, #-16] // ..................e............................................................... + ldr q4, [x7], #32 // ..................................e............................................... + uzp1 v20.8H, v8.8H, v26.8H // ...................e.............................................................. + uzp2 v26.8H, v8.8H, v26.8H // ....................e............................................................. + ld1 {v8.8H}, [x6], #16 // .........................e........................................................ + uzp1 v9.8H, v4.8H, v3.8H // ....................................e............................................. + ldr q25, [x11, #16] // ........................................................e......................... + ldr q29, [x11], #32 // .......................................................e.......................... + ld1 {v12.8H}, [x9], #16 // ..........................................e....................................... + ldr q14, [x8], #32 // ......................................e........................................... + smlal2 v24.4S, v16.8H, v19.8H // ................e................................................................. + smlal2 v13.4S, v23.8H, v0.8H // .......................................................................*.......... 
+ smlal v18.4S, v23.4H, v0.4H // ......................................................................*........... + ld1 {v23.8H}, [x3], #16 // ........e......................................................................... + smlal2 v24.4S, v20.8H, v27.8H // ...............................e.................................................. + uzp2 v7.8H, v18.8H, v13.8H // ........................................................................*......... + uzp1 v10.8H, v29.8H, v25.8H // .........................................................e........................ + str q5, [x0], #32 // ................................................................................l. + zip1 v5.8H, v7.8H, v15.8H // ..............................................................................*... + + // ----------------------------------------------------------------------------------------------------------------- new position ------------------------------------------------------------------------------------------------------------------> + // 0 25 50 75 100 125 150 175 200 225 + // |------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|------------------------|---------------- + // ldr q12, [x1], #32 // ....................................e..........................................'......................................~..........................................'......................................~......................................... + // ldr q13, [x1, #-16] // ................e..............................................................'..................~..............................................................'..................~............................................................. 
+ // uzp1 v3.8h, v12.8h, v13.8h // ........................................e......................................'..........................................~......................................'..........................................~..................................... + // uzp2 v4.8h, v12.8h, v13.8h // .........................................e.....................................'...........................................~.....................................'...........................................~.................................... + // ldr q12, [x2], #32 // e..............................................................................'..~..............................................................................'..~............................................................................. + // ldr q13, [x2, #-16] // ...e...........................................................................'.....~...........................................................................'.....~.......................................................................... + // uzp1 v5.8h, v12.8h, v13.8h // .............................................e.................................'...............................................~.................................'...............................................~................................ + // uzp2 v6.8h, v12.8h, v13.8h // ..........................................e....................................'............................................~....................................'............................................~................................... + // ld1 {v7.8h}, [x3], #16 // .........................................................................e.....'...........................................................................~.....'...........................................................................~.... 
+ // smull v8.4s, v3.4h, v5.4h // ....~..........................................................................'......*..........................................................................'......~......................................................................... + // smull2 v10.4s, v3.8h, v5.8h // ...............................................................................'.*...............................................................................'.~.............................................................................. + // smlal v8.4s, v4.4h, v7.4h // ...........~...................................................................'.............*...................................................................'.............~.................................................................. + // smlal2 v10.4s, v4.8h, v7.8h // ..~............................................................................'....*............................................................................'....~........................................................................... + // smull v9.4s, v3.4h, v6.4h // ......~........................................................................'........*........................................................................'........~....................................................................... + // smull2 v11.4s, v3.8h, v6.8h // ..................................................e............................'....................................................~............................'....................................................~........................... + // smlal v9.4s, v4.4h, v5.4h // .........~.....................................................................'...........*.....................................................................'...........~.................................................................... 
+ // smlal2 v11.4s, v4.8h, v5.8h // ......................................................................e........'........................................................................~........'........................................................................~....... + // ldr q12, [x4], #32 // ...........................................................e...................'.............................................................~...................'.............................................................~.................. + // ldr q13, [x4, #-16] // ............................................................e..................'..............................................................~..................'..............................................................~................. + // uzp1 v3.8h, v12.8h, v13.8h // ..............................................................e................'................................................................~................'................................................................~............... + // uzp2 v4.8h, v12.8h, v13.8h // ...............................................................e...............'.................................................................~...............'.................................................................~.............. + // ldr q12, [x5], #32 // .......e.......................................................................'.........~.......................................................................'.........~...................................................................... + // ldr q13, [x5, #-16] // ..........e....................................................................'............~....................................................................'............~................................................................... 
+ // uzp1 v5.8h, v12.8h, v13.8h // ......................................e........................................'........................................~........................................'........................................~....................................... + // uzp2 v6.8h, v12.8h, v13.8h // ..............e................................................................'................~................................................................'................~............................................................... + // ld1 {v7.8h}, [x6], #16 // ................................................................e..............'..................................................................~..............'..................................................................~............. + // smlal v8.4s, v3.4h, v5.4h // ...............~...............................................................'.................*...............................................................'.................~.............................................................. + // smlal2 v10.4s, v3.8h, v5.8h // .....~.........................................................................'.......*.........................................................................'.......~........................................................................ + // smlal v8.4s, v4.4h, v7.4h // ..................~............................................................'....................*............................................................'....................~........................................................... + // smlal2 v10.4s, v4.8h, v7.8h // ........~......................................................................'..........*......................................................................'..........~..................................................................... 
+ // smlal v9.4s, v3.4h, v6.4h // ............~..................................................................'..............*..................................................................'..............~................................................................. + // smlal2 v11.4s, v3.8h, v6.8h // ..........................................................................e....'............................................................................~....'............................................................................~... + // smlal v9.4s, v4.4h, v5.4h // .................~.............................................................'...................*.............................................................'...................~............................................................ + // smlal2 v11.4s, v4.8h, v5.8h // ...............................................................................*.................................................................................~................................................................................ + // ldr q12, [x7], #32 // .............................................................e.................'...............................................................~.................'...............................................................~................ + // ldr q13, [x7, #-16] // ....................................................e..........................'......................................................~..........................'......................................................~......................... + // uzp1 v3.8h, v12.8h, v13.8h // .................................................................e.............'...................................................................~.............'...................................................................~............ 
+ // uzp2 v4.8h, v12.8h, v13.8h // ...............................................................................'*................................................................................'~............................................................................... + // ldr q12, [x8], #32 // .....................................................................e.........'.......................................................................~.........'.......................................................................~........ + // ldr q13, [x8, #-16] // .....................................................e.........................'.......................................................~.........................'.......................................................~........................ + // uzp1 v5.8h, v12.8h, v13.8h // .............~.................................................................'...............*.................................................................'...............~................................................................ + // uzp2 v6.8h, v12.8h, v13.8h // ...................~...........................................................'.....................*...........................................................'.....................~.......................................................... + // ld1 {v7.8h}, [x9], #16 // ....................................................................e..........'......................................................................~..........'......................................................................~......... + // smlal v8.4s, v3.4h, v5.4h // .......................~.......................................................'.........................*.......................................................'.........................~...................................................... 
+ // smlal2 v10.4s, v3.8h, v5.8h // ....................~..........................................................'......................*..........................................................'......................~......................................................... + // smlal v8.4s, v4.4h, v7.4h // ...........................~...................................................'.............................*...................................................'.............................~.................................................. + // smlal2 v10.4s, v4.8h, v7.8h // ........................~......................................................'..........................*......................................................'..........................~..................................................... + // smlal v9.4s, v3.4h, v6.4h // ......................~........................................................'........................*........................................................'........................~....................................................... + // smlal2 v11.4s, v3.8h, v6.8h // .....................~.........................................................'.......................*.........................................................'.......................~........................................................ + // smlal v9.4s, v4.4h, v5.4h // ..........................~....................................................'............................*....................................................'............................~................................................... + // smlal2 v11.4s, v4.8h, v5.8h // .........................~.....................................................'...........................*.....................................................'...........................~.................................................... 
+ // ldr q12, [x10], #32 // ......................................................e........................'........................................................~........................'........................................................~....................... + // ldr q13, [x10, #-16] // .......................................................e.......................'.........................................................~.......................'.........................................................~...................... + // uzp1 v3.8h, v12.8h, v13.8h // .........................................................e.....................'...........................................................~.....................'...........................................................~.................... + // uzp2 v4.8h, v12.8h, v13.8h // ..........................................................e....................'............................................................~....................'............................................................~................... + // ldr q12, [x11], #32 // ...................................................................e...........'.....................................................................~...........'.....................................................................~.......... + // ldr q13, [x11, #-16] // ..................................................................e............'....................................................................~............'....................................................................~........... + // uzp1 v5.8h, v12.8h, v13.8h // ............................................................................e..'..............................................................................~..'..............................................................................~. 
+ // uzp2 v6.8h, v12.8h, v13.8h // .~.............................................................................'...*.............................................................................'...~............................................................................ + // ld1 {v7.8h}, [x12], #16 // ........................................................e......................'..........................................................~......................'..........................................................~..................... + // smlal v8.4s, v3.4h, v5.4h // ...............................~...............................................'.................................*...............................................'.................................~.............................................. + // smlal2 v10.4s, v3.8h, v5.8h // ............................~..................................................'..............................*..................................................'..............................~................................................. + // smlal v8.4s, v4.4h, v7.4h // ...................................~...........................................'.....................................*...........................................'.....................................~.......................................... + // smlal2 v10.4s, v4.8h, v7.8h // ................................~..............................................'..................................*..............................................'..................................~............................................. + // smlal v9.4s, v3.4h, v6.4h // ..............................~................................................'................................*................................................'................................~............................................... 
+ // smlal2 v11.4s, v3.8h, v6.8h // .............................~.................................................'...............................*.................................................'...............................~................................................ + // smlal v9.4s, v4.4h, v5.4h // ..................................~............................................'....................................*............................................'....................................~........................................... + // smlal2 v11.4s, v4.8h, v5.8h // .................................~.............................................'...................................*.............................................'...................................~............................................ + // uzp1 v28.8h, v8.8h, v10.8h // ..............................................~................................'................................................*................................'................................................~............................... + // mul v28.8h, v28.8h, v2.8h // ................................................~..............................'..................................................*..............................'..................................................~............................. + // smlal v8.4s, v28.4h, v0.4h // ........................................................................~......'..........................................................................*......'..........................................................................~..... + // smlal2 v10.4s, v28.8h, v0.8h // .......................................................................~.......'.........................................................................*.......'.........................................................................~...... 
+ // uzp2 v26.8h, v8.8h, v10.8h // ...........................................................................~...'.............................................................................*...'.............................................................................~.. + // uzp1 v28.8h, v9.8h, v11.8h // .....................................~.........................................'.......................................*.........................................'.......................................~........................................ + // mul v28.8h, v28.8h, v2.8h // .......................................~.......................................'.........................................*.......................................'.........................................~...................................... + // smlal v9.4s, v28.4h, v0.4h // ...........................................~...................................'.............................................*...................................'.............................................~.................................. + // smlal2 v11.4s, v28.8h, v0.8h // ............................................~..................................'..............................................*..................................'..............................................~................................. + // uzp2 v27.8h, v9.8h, v11.8h // .................................................~.............................'...................................................*.............................'...................................................~............................ + // zip1 v12.8h, v26.8h, v27.8h // ..............................................................................~'................................................................................*'................................................................................ 
+ // zip2 v13.8h, v26.8h, v27.8h // ...............................................~...............................'.................................................~...............................'.................................................l.............................. + // str q12, [x0], #32 // .............................................................................~.'...............................................................................~.'...............................................................................l + // str q13, [x0, #-16] // ...................................................~...........................'.....................................................~...........................'.....................................................l.......................... + + sub count, count, #1 + cbnz count, 1b + // Instructions: 50 + // Expected cycles: 56 + // Expected IPC: 0.89 + // + // Cycle bound: 56.0 + // IPC bound: 0.89 + // + // Wall time: 4.16s + // User time: 4.16s + // Loop tail: the 50 instructions below run once, when the cbnz back-branch above falls through. They issue no loads (only smull/smlal/uzp/zip/mul/str) and write 64 contiguous bytes through x0, which ends up advanced by 64. NOTE(review): the dot/star column on each line looks like SLOTHY scheduling output (slot of the instruction before reordering) - confirm against the generator, and regenerate rather than hand-edit. + // --------------- original position ---------------> + // 0 25 + // |------------------------| + smull2 v17.4S, v31.8H, v19.8H // ..*............................................... + uzp2 v1.8H, v14.8H, v6.8H // ................*................................. + smull v18.4S, v31.4H, v21.4H // .......*.......................................... + smlal2 v24.4S, v26.8H, v28.8H // *................................................. + smlal2 v17.4S, v16.8H, v23.8H // ....*............................................. + smull v21.4S, v31.4H, v19.4H // .....*............................................ + smlal v18.4S, v16.4H, v19.4H // .........*........................................ + uzp2 v31.8H, v4.8H, v3.8H // .*................................................ + uzp1 v3.8H, v14.8H, v6.8H // ............*..................................... + smlal v21.4S, v16.4H, v23.4H // ..........*....................................... 
+ smlal v18.4S, v20.4H, v27.4H // ...........*...................................... + uzp2 v14.8H, v29.8H, v25.8H // ...*.............................................. + smlal2 v17.4S, v20.8H, v28.8H // ......*........................................... + smlal v21.4S, v20.4H, v28.4H // .............*.................................... + smlal v18.4S, v26.4H, v28.4H // ..............*................................... + smlal2 v24.4S, v9.8H, v1.8H // ..................*............................... + smlal2 v17.4S, v26.8H, v8.8H // ........*......................................... + smlal v21.4S, v26.4H, v8.4H // ...............*.................................. + smlal v18.4S, v9.4H, v1.4H // ...................*.............................. + smlal2 v24.4S, v31.8H, v3.8H // ......................*........................... + smlal2 v17.4S, v9.8H, v3.8H // .................*................................ + smlal v21.4S, v9.4H, v3.4H // ....................*............................. + smlal v18.4S, v31.4H, v3.4H // .......................*.......................... + smlal2 v24.4S, v30.8H, v14.8H // ..........................*....................... + smlal2 v17.4S, v31.8H, v12.8H // .....................*............................ + smlal v21.4S, v31.4H, v12.4H // ........................*......................... + smlal v18.4S, v30.4H, v14.4H // ...........................*...................... + smlal2 v24.4S, v11.8H, v10.8H // ..............................*................... + smlal2 v17.4S, v30.8H, v10.8H // .........................*........................ + smlal v21.4S, v30.4H, v10.4H // ............................*..................... + smlal v18.4S, v11.4H, v10.4H // ...............................*.................. + zip2 v19.8H, v7.8H, v15.8H // ......................................*........... + smlal2 v17.4S, v11.8H, v22.8H // .............................*.................... 
+ smlal v21.4S, v11.4H, v22.4H // ................................*................. + uzp1 v23.8H, v18.8H, v24.8H // .................................*................ + str q19, [x0, #16] // .........................................*........ + mul v19.8H, v23.8H, v2.8H // ..................................*............... + uzp1 v23.8H, v21.8H, v17.8H // .....................................*............ + str q5, [x0], #32 // .............................................*.... + mul v26.8H, v23.8H, v2.8H // .......................................*.......... + smlal v18.4S, v19.4H, v0.4H // ...................................*.............. + smlal2 v24.4S, v19.8H, v0.8H // ....................................*............. + smlal v21.4S, v26.4H, v0.4H // ...........................................*...... + smlal2 v17.4S, v26.8H, v0.8H // ..........................................*....... + uzp2 v13.8H, v18.8H, v24.8H // ........................................*......... + uzp2 v19.8H, v21.8H, v17.8H // ............................................*..... + zip1 v23.8H, v19.8H, v13.8H // ..............................................*... + zip2 v19.8H, v19.8H, v13.8H // ...............................................*.. + str q23, [x0], #32 // .................................................* + str q19, [x0, #-16] // ................................................*. + // Commented-out listing below: the same 50-instruction tail under different register names and in a different order; presumably the scheduler's input, with the column giving each instruction's slot in the live code above - TODO confirm. + // ----------------- new position ------------------> + // 0 25 + // |------------------------|------------------------ + // smlal2 v24.4S, v26.8H, v28.8H // ...*.............................................. + // uzp2 v4.8H, v4.8H, v3.8H // .......*.......................................... + // smull2 v13.4S, v31.8H, v19.8H // *................................................. 
+ // uzp2 v1.8H, v29.8H, v25.8H // ...........*...................................... + // smlal2 v13.4S, v16.8H, v23.8H // ....*............................................. 
+ // smull v18.4S, v31.4H, v19.4H // .....*............................................ + // smlal2 v13.4S, v20.8H, v28.8H // ............*..................................... + // smull v29.4S, v31.4H, v21.4H // ..*............................................... + // smlal2 v13.4S, v26.8H, v8.8H // ................*................................. + // smlal v29.4S, v16.4H, v19.4H // ......*........................................... + // smlal v18.4S, v16.4H, v23.4H // .........*........................................ + // smlal v29.4S, v20.4H, v27.4H // ..........*....................................... + // uzp1 v31.8H, v14.8H, v6.8H // ........*......................................... + // smlal v18.4S, v20.4H, v28.4H // .............*.................................... + // smlal v29.4S, v26.4H, v28.4H // ..............*................................... + // smlal v18.4S, v26.4H, v8.4H // .................*................................ + // uzp2 v26.8H, v14.8H, v6.8H // .*................................................ + // smlal2 v13.4S, v9.8H, v31.8H // ....................*............................. + // smlal2 v24.4S, v9.8H, v26.8H // ...............*.................................. + // smlal v29.4S, v9.4H, v26.4H // ..................*............................... + // smlal v18.4S, v9.4H, v31.4H // .....................*............................ + // smlal2 v13.4S, v4.8H, v12.8H // ........................*......................... + // smlal2 v24.4S, v4.8H, v31.8H // ...................*.............................. + // smlal v29.4S, v4.4H, v31.4H // ......................*........................... + // smlal v18.4S, v4.4H, v12.4H // .........................*........................ + // smlal2 v13.4S, v30.8H, v10.8H // ............................*..................... + // smlal2 v24.4S, v30.8H, v1.8H // .......................*.......................... 
+ // smlal v29.4S, v30.4H, v1.4H // ..........................*....................... + // smlal v18.4S, v30.4H, v10.4H // .............................*.................... + // smlal2 v13.4S, v11.8H, v22.8H // ................................*................. + // smlal2 v24.4S, v11.8H, v10.8H // ...........................*...................... + // smlal v29.4S, v11.4H, v10.4H // ..............................*................... + // smlal v18.4S, v11.4H, v22.4H // .................................*................ + // uzp1 v31.8H, v29.8H, v24.8H // ..................................*............... + // mul v19.8H, v31.8H, v2.8H // ....................................*............. + // smlal v29.4S, v19.4H, v0.4H // ........................................*......... + // smlal2 v24.4S, v19.8H, v0.8H // .........................................*........ + // uzp1 v26.8H, v18.8H, v13.8H // .....................................*............ + // zip2 v14.8H, v7.8H, v15.8H // ...............................*.................. + // mul v23.8H, v26.8H, v2.8H // .......................................*.......... + // uzp2 v15.8H, v29.8H, v24.8H // ............................................*..... + // str q14, [x0, #16] // ...................................*.............. + // smlal2 v13.4S, v23.8H, v0.8H // ...........................................*...... + // smlal v18.4S, v23.4H, v0.4H // ..........................................*....... + // uzp2 v7.8H, v18.8H, v13.8H // .............................................*.... + // str q5, [x0], #32 // ......................................*........... + // zip1 v5.8H, v7.8H, v15.8H // ..............................................*... + // zip2 v14.8H, v7.8H, v15.8H // ...............................................*.. + // str q14, [x0, #16] // .................................................* + // str q5, [x0], #32 // ................................................*. 
+
+
+        pop_stack
+        ret
+#endif /* MLKEM_K == 4 */
+
+#endif /* MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT */
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_asm_clean.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_asm_clean.S
new file mode 100644
index 0000000000..722dc0f49e
--- /dev/null
+++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_asm_clean.S
@@ -0,0 +1,341 @@
+/*
+ * Copyright (c) 2024 The mlkem-native project authors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/*************************************************
+ * Name:        rej_uniform_asm_clean
+ *
+ * Description: Run rejection sampling on uniform random bytes to generate
+ *              uniform random integers mod q
+ *
+ * Arguments:   - int16_t *r:          pointer to output buffer of MLKEM_N
+ *                                     16-bit coefficients.
+ *              - const uint8_t *buf:  pointer to input buffer
+ *                                     (assumed to be uniform random bytes)
+ *              - unsigned int buflen: length of input buffer in bytes.
+ *                                     Must be a multiple of 24.
+ *              - x3 (table_idx):      base of a table of 16-byte `tbl` index vectors, one per 8-bit valid-lane bitmap (presumably rej_uniform_table; confirm against the C caller)
+ * Returns number of sampled 16-bit integers (at most MLKEM_N).
+ **************************************************/
+#include "common.h"
+#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \
+    defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT)
+
+// We save the output on the stack first, and copy to the actual
+// output buffer only at the end. This is because the main loop can overwrite
+// by up to 62 bytes, which we account for here (we use 64 bytes for alignment).
+#define STACK_SIZE (2*MLKEM_N + 64)
+#define STACK_OFFSET_TMP_OUTPUT 0
+
+.macro push_stack
+        sub sp, sp, #STACK_SIZE
+.endm
+
+.macro pop_stack
+        add sp, sp, #STACK_SIZE
+.endm
+
+        /* Parameters */
+        output          .req x0   // final destination, written only in final_copy
+        buf             .req x1   // input stream, post-incremented by each ld3
+        buflen          .req w2   // input bytes not yet consumed
+        table_idx       .req x3   // base of the 16-byte-per-entry tbl index table
+
+        len             .req w4   // target coefficient count, set to MLKEM_N below
+
+        /* Temporary output on the stack */
+        output_tmp      .req x7   // write cursor into the stack buffer
+        output_tmp_base .req x8   // start of the stack buffer
+
+        /* Number of coefficients sampled so far */
+        count           .req w9
+        buf_consumed    .req w10  // alias is not referenced anywhere in this file
+
+        /* Temporary registers */
+        tmp              .req w11
+        final_copy_count .req w11 // shares w11 with tmp; used only in final_copy
+
+        rec_idx_0       .req w12  // 8-bit valid-lane bitmaps, used as table row numbers
+        rec_idx_1       .req w13
+        rec_idx_2       .req w14
+        rec_idx_3       .req w15
+
+        ctr0            .req w12  // per-vector accepted counts; reuse w12-w15 after the table loads
+        ctr1            .req w13
+        ctr2            .req w14
+        ctr3            .req w15
+
+        ctr01           .req ctr0 // pair sums overwrite ctr0 and ctr2
+        ctr23           .req ctr2
+
+        /* Vector registers */
+
+        buf0            .req v0   // ld3 de-interleaves bytes 0, 1, 2 of each triple into buf0-buf2
+        buf1            .req v1
+        buf2            .req v2
+
+        tmp0            .req v4
+        tmp1            .req v5
+        tmp2            .req v6
+        tmp3            .req v7
+
+        sign0           .req v4   // sign0-sign3 reuse v4-v7 once tmp0-tmp3 are dead
+        sign1           .req v5
+        sign2           .req v6
+        sign3           .req v7
+
+        val0            .req v16  // candidate coefficients, eight 16-bit lanes each
+        val0q           .req q16
+        val1            .req v17
+        val1q           .req q17
+        val2            .req v18
+        val2q           .req q18
+        val3            .req v19
+        val3q           .req q19
+
+        t0              .req s20  // scalar results of the uaddlv reductions
+        t1              .req s21
+        t2              .req s22
+        t3              .req s23
+
+        table0          .req v24  // tbl index vectors loaded from table_idx
+        table0q         .req q24
+        table1          .req v25
+        table1q         .req q25
+        table2          .req v26
+        table2q         .req q26
+        table3          .req v27
+        table3q         .req q27
+
+        mlkem_q         .req v30  // MLKEM_Q broadcast to all eight lanes
+        bits            .req v31  // lane i holds 1 << i (see c_bit_table)
+        bits_q          .req q31
+
+.text
+/* Literal pool */
+.p2align 4
+c_bit_table:
+        .short 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80
+
+.align 4
+.global MLKEM_ASM_NAMESPACE(rej_uniform_asm_clean)
+MLKEM_ASM_NAMESPACE(rej_uniform_asm_clean):
+        push_stack
+
+        ldr bits_q, c_bit_table
+        movz tmp, #MLKEM_Q
+        dup mlkem_q.8h, tmp
+
+        add output_tmp_base, sp, #STACK_OFFSET_TMP_OUTPUT
+        mov output_tmp, output_tmp_base
+
+        mov count, #0
+        mov len, #MLKEM_N
+
+        cmp buflen, #48
+        b.lo loop48_end           // fewer than 48 bytes: go straight to the 24-byte tail
+
+loop48:
+        // Finish once we've generated sufficiently many coefficients
+        cmp count, len
+        b.hs memory_copy
+
+        // First, we unpack the byte stream into a stream of unsigned
+        // coefficients, interpreting each consecutive 3 bytes as two
+        // unsigned 12-bit coefficients, presented as 16-bit integers.
+        //
+        // We handle 16 such triples at a time, and use ld3 for the required
+        // de-interleaving of the byte stream.
+        sub buflen, buflen, #48
+        ld3 {buf0.16b, buf1.16b, buf2.16b}, [buf], #48
+
+        // Unpack 16 triples of bytes into 16 pairs of 16-bit integers,
+        // represented as 4 vectors val0-val3.
+        zip1 tmp0.16b, buf0.16b, buf1.16b   // lanes = byte0 | byte1 << 8, triples 0-7
+        zip2 tmp1.16b, buf0.16b, buf1.16b   // same, triples 8-15
+        zip1 tmp2.16b, buf1.16b, buf2.16b   // lanes = byte1 | byte2 << 8, triples 0-7
+        zip2 tmp3.16b, buf1.16b, buf2.16b   // same, triples 8-15
+
+        bic tmp0.8h, #0xf0, lsl 8           // clear bits 12-15: first coefficient of each triple
+        bic tmp1.8h, #0xf0, lsl 8
+        ushr tmp2.8h, tmp2.8h, #4           // drop the low nibble: second coefficient of each triple
+        ushr tmp3.8h, tmp3.8h, #4
+
+        zip1 val0.8h, tmp0.8h, tmp2.8h      // interleave first/second back into stream order
+        zip2 val1.8h, tmp0.8h, tmp2.8h
+        zip1 val2.8h, tmp1.8h, tmp3.8h
+        zip2 val3.8h, tmp1.8h, tmp3.8h
+
+        // At this point, val0-val3 hold the 12-bit candidates to do rejection
+        // sampling on. For each of them, do the following:
+        // - Check which coefficients are within range, and represent the set
+        //   of lane-indices of those coefficients as an 8-bit bitmap.
+        // - Move the respective lanes to the front of the vector. This is the
+        //   most complex part, and is done by interpreting the 8-bit bitmap as
+        //   an index into a lookup table giving the lane-table to be used for
+        //   the `tbl` instruction.
+        // - Write the vector to the output buffer, but merely increase the output
+        //   buffer pointer by the number of valid coefficients.
+
+        // Set valid lanes to -1 (0b1...1)
+        cmhi sign0.8h, mlkem_q.8h, val0.8h  // unsigned compare: MLKEM_Q > val
+        cmhi sign1.8h, mlkem_q.8h, val1.8h
+        cmhi sign2.8h, mlkem_q.8h, val2.8h
+        cmhi sign3.8h, mlkem_q.8h, val3.8h
+
+        // If lane i is valid and has value -1, retain only i-th bit
+        and sign0.16b, sign0.16b, bits.16b
+        and sign1.16b, sign1.16b, bits.16b
+        and sign2.16b, sign2.16b, bits.16b
+        and sign3.16b, sign3.16b, bits.16b
+
+        // Get 8-bit bitmap of valid lane indices by adding lanes
+        uaddlv t0, sign0.8h
+        uaddlv t1, sign1.8h
+        uaddlv t2, sign2.8h
+        uaddlv t3, sign3.8h
+
+        fmov rec_idx_0, t0
+        fmov rec_idx_1, t1
+        fmov rec_idx_2, t2
+        fmov rec_idx_3, t3
+
+        ldr table0q, [table_idx, rec_idx_0, uxtw #4]   // row address = table_idx + 16 * bitmap
+        ldr table1q, [table_idx, rec_idx_1, uxtw #4]
+        ldr table2q, [table_idx, rec_idx_2, uxtw #4]
+        ldr table3q, [table_idx, rec_idx_3, uxtw #4]
+
+        // Compute number of valid coefficients. Recall that at this
+        // point, lane i has value 2^i (hence popcount 1) if its coefficient
+        // is valid, and 0 otherwise.
+        cnt sign0.16b, sign0.16b
+        cnt sign1.16b, sign1.16b
+        cnt sign2.16b, sign2.16b
+        cnt sign3.16b, sign3.16b
+
+        // Extract number of valid coefficients
+        uaddlv t0, sign0.8h
+        uaddlv t1, sign1.8h
+        uaddlv t2, sign2.8h
+        uaddlv t3, sign3.8h
+
+        fmov ctr0, t0                       // overwrites rec_idx_0-3, which are dead after the loads above
+        fmov ctr1, t1
+        fmov ctr2, t2
+        fmov ctr3, t3
+
+        // Move valid coefficients to the front
+        tbl val0.16b, {val0.16b}, table0.16b
+        tbl val1.16b, {val1.16b}, table1.16b
+        tbl val2.16b, {val2.16b}, table2.16b
+        tbl val3.16b, {val3.16b}, table3.16b
+
+        str val0q, [output_tmp]             // always stores 16 bytes ...
+        add output_tmp, output_tmp, ctr0, uxtw #1   // ... but advances by only 2 * accepted
+
+        str val1q, [output_tmp]
+        add output_tmp, output_tmp, ctr1, uxtw #1
+
+        str val2q, [output_tmp]
+        add output_tmp, output_tmp, ctr2, uxtw #1
+
+        str val3q, [output_tmp]
+        add output_tmp, output_tmp, ctr3, uxtw #1
+
+        add ctr01, ctr0, ctr1               // ctr01 aliases ctr0, so ctr0 is clobbered here
+        add ctr23, ctr2, ctr3               // ctr23 aliases ctr2
+        add count, count, ctr01
+        add count, count, ctr23
+
+        cmp buflen, #48
+        b.hs loop48
+loop48_end:
+
+        // Finish once we've generated sufficiently many coefficients
+        cmp count, len
+        b.hs memory_copy
+
+        cmp buflen, #24
+        b.lo memory_copy
+
+        sub buflen, buflen, #24             // tail: one 24-byte block, same steps on two vectors
+        ld3 {buf0.8b, buf1.8b, buf2.8b}, [buf], #24
+
+        zip1 tmp0.16b, buf0.16b, buf1.16b
+        zip1 tmp1.16b, buf1.16b, buf2.16b
+
+        bic tmp0.8h, #0xf0, lsl 8
+        ushr tmp1.8h, tmp1.8h, #4
+
+        zip1 val0.8h, tmp0.8h, tmp1.8h
+        zip2 val1.8h, tmp0.8h, tmp1.8h
+
+        cmhi sign0.8h, mlkem_q.8h, val0.8h
+        cmhi sign1.8h, mlkem_q.8h, val1.8h
+
+        and sign0.16b, sign0.16b, bits.16b
+        and sign1.16b, sign1.16b, bits.16b
+
+        uaddlv t0, sign0.8h
+        uaddlv t1, sign1.8h
+
+        fmov rec_idx_0, t0
+        fmov rec_idx_1, t1
+
+        ldr table0q, [table_idx, rec_idx_0, uxtw #4]
+        ldr table1q, [table_idx, rec_idx_1, uxtw #4]
+
+        cnt sign0.16b, sign0.16b
+        cnt sign1.16b, sign1.16b
+
+        uaddlv t0, sign0.8h
+        uaddlv t1, sign1.8h
+
+        fmov ctr0, t0
+        fmov ctr1, t1
+
+        tbl val0.16b, {val0.16b}, table0.16b
+        tbl val1.16b, {val1.16b}, table1.16b
+
+        str val0q, [output_tmp]
+        add output_tmp, output_tmp, ctr0, uxtw #1
+
+        str val1q, [output_tmp]
+        add output_tmp, output_tmp, ctr1, uxtw #1
+
+        add count, count, ctr0
+        add count, count, ctr1
+
+memory_copy:
+        // count = min(count,len)
+        cmp count, len
+        csel count, count, len, lo
+
+        // Always copy MLKEM_N coefficients from the stack to the destination,
+        // even if not all of them may be valid. This simplifies the loop and
+        // allows us to stick to vectorized code.
+        mov final_copy_count, #0
+        mov output_tmp, output_tmp_base
+final_copy:
+        ldr val0q, [output_tmp], #64        // 64 bytes = 32 coefficients per iteration
+        ldr val1q, [output_tmp, #-48]
+        ldr val2q, [output_tmp, #-32]
+        ldr val3q, [output_tmp, #-16]
+        str val0q, [output], #64
+        str val1q, [output, #-48]
+        str val2q, [output, #-32]
+        str val3q, [output, #-16]
+        add final_copy_count, final_copy_count, #32
+        cmp final_copy_count, #MLKEM_N
+        b.lt final_copy
+
+        mov w0, count                       // return value: number of coefficients sampled
+        b return                            // target is the label directly below
+
+return:
+        pop_stack
+        ret
+
+#endif /* defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) ||
+          defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT) */
diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_table.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_table.c
new file mode 100644
index 0000000000..507660349d
--- /dev/null
+++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/aarch64/src/rej_uniform_table.c
@@ -0,0 +1,288 @@
+/*
+ * Copyright (c) 2024 The mlkem-native project authors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/*
+ * WARNING: This file is auto-generated from scripts/autogen
+ *          Do not modify it directly.
+ */
+
+#include "common.h"
+
+#if defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_CLEAN) || \
+    defined(MLKEM_NATIVE_ARITH_BACKEND_AARCH64_OPT)
+
+#include <stdint.h>
+#include "arith_native_aarch64.h"
+
+/*
+ * Lookup table used by rejection sampling of the public matrix.
+ * See autogen for details.
+ */ +ALIGN const uint8_t rej_uniform_table[] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 0 */, + 0, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 1 */, + 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 2 */, + 0, 1, 2, 3, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 3 */, + 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 4 */, + 0, 1, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 5 */, + 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 6 */, + 0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 7 */, + 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 8 */, + 0, 1, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 9 */, + 2, 3, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 10 */, + 0, 1, 2, 3, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 11 */, + 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 12 */, + 0, 1, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 13 */, + 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 14 */, + 0, 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1 /* 15 */, + 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 16 */, + 0, 1, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 17 */, + 2, 3, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 18 */, + 0, 1, 2, 3, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 19 */, + 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 20 */, + 0, 1, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 21 */, + 2, 3, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 22 */, + 0, 1, 2, 3, 4, 5, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 23 */, + 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 24 */, + 0, 1, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 25 */, + 2, 3, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 26 */, + 0, 1, 2, 3, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, 
-1 /* 27 */, + 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 28 */, + 0, 1, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 29 */, + 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1 /* 30 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1 /* 31 */, + 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 32 */, + 0, 1, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 33 */, + 2, 3, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 34 */, + 0, 1, 2, 3, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 35 */, + 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 36 */, + 0, 1, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 37 */, + 2, 3, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 38 */, + 0, 1, 2, 3, 4, 5, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 39 */, + 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 40 */, + 0, 1, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 41 */, + 2, 3, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 42 */, + 0, 1, 2, 3, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 43 */, + 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 44 */, + 0, 1, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 45 */, + 2, 3, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 46 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, -1, -1, -1, -1, -1, -1 /* 47 */, + 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 48 */, + 0, 1, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 49 */, + 2, 3, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 50 */, + 0, 1, 2, 3, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 51 */, + 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 52 */, + 0, 1, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 53 */, + 2, 3, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 54 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 55 */, + 6, 7, 8, 9, 10, 11, 
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 56 */, + 0, 1, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 57 */, + 2, 3, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 58 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 59 */, + 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1, -1, -1 /* 60 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 61 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1, -1, -1 /* 62 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, -1, -1, -1, -1 /* 63 */, + 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 64 */, + 0, 1, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 65 */, + 2, 3, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 66 */, + 0, 1, 2, 3, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 67 */, + 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 68 */, + 0, 1, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 69 */, + 2, 3, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 70 */, + 0, 1, 2, 3, 4, 5, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 71 */, + 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 72 */, + 0, 1, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 73 */, + 2, 3, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 74 */, + 0, 1, 2, 3, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 75 */, + 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 76 */, + 0, 1, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 77 */, + 2, 3, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 78 */, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, -1, -1, -1, -1, -1, -1 /* 79 */, + 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 80 */, + 0, 1, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 81 */, + 2, 3, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 82 */, + 0, 1, 2, 3, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 83 */, + 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, 
-1 /* 84 */, + 0, 1, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 85 */, + 2, 3, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 86 */, + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 87 */, + 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 88 */, + 0, 1, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 89 */, + 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 90 */, + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 91 */, + 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 92 */, + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 93 */, + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1, -1, -1 /* 94 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, -1, -1, -1, -1 /* 95 */, + 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 96 */, + 0, 1, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 97 */, + 2, 3, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 98 */, + 0, 1, 2, 3, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 99 */, + 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 100 */, + 0, 1, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 101 */, + 2, 3, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 102 */, + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 103 */, + 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 104 */, + 0, 1, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 105 */, + 2, 3, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 106 */, + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 107 */, + 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 108 */, + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 109 */, + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 110 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, -1, -1, -1, -1 /* 111 */, + 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 112 */, + 0, 1, 8, 9, 
10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 113 */, + 2, 3, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 114 */, + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 115 */, + 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 116 */, + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 117 */, + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 118 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 119 */, + 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1, -1, -1 /* 120 */, + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 121 */, + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 122 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 123 */, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1, -1, -1 /* 124 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 125 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1, -1, -1 /* 126 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, -1, -1 /* 127 */, + 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 128 */, + 0, 1, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 129 */, + 2, 3, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 130 */, + 0, 1, 2, 3, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 131 */, + 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 132 */, + 0, 1, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 133 */, + 2, 3, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 134 */, + 0, 1, 2, 3, 4, 5, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 135 */, + 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 136 */, + 0, 1, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 137 */, + 2, 3, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 138 */, + 0, 1, 2, 3, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 139 */, + 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 140 */, + 0, 1, 4, 5, 6, 7, 14, 
15, -1, -1, -1, -1, -1, -1, -1, -1 /* 141 */, + 2, 3, 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 142 */, + 0, 1, 2, 3, 4, 5, 6, 7, 14, 15, -1, -1, -1, -1, -1, -1 /* 143 */, + 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 144 */, + 0, 1, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 145 */, + 2, 3, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 146 */, + 0, 1, 2, 3, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 147 */, + 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 148 */, + 0, 1, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 149 */, + 2, 3, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 150 */, + 0, 1, 2, 3, 4, 5, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 151 */, + 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 152 */, + 0, 1, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 153 */, + 2, 3, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 154 */, + 0, 1, 2, 3, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 155 */, + 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 156 */, + 0, 1, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 157 */, + 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1, -1, -1 /* 158 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, -1, -1, -1, -1 /* 159 */, + 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 160 */, + 0, 1, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 161 */, + 2, 3, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 162 */, + 0, 1, 2, 3, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 163 */, + 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 164 */, + 0, 1, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 165 */, + 2, 3, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 166 */, + 0, 1, 2, 3, 4, 5, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 167 */, + 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 168 */, + 0, 1, 6, 7, 10, 11, 14, 15, -1, -1, 
-1, -1, -1, -1, -1, -1 /* 169 */, + 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 170 */, + 0, 1, 2, 3, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 171 */, + 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 172 */, + 0, 1, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 173 */, + 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 174 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 14, 15, -1, -1, -1, -1 /* 175 */, + 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 176 */, + 0, 1, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 177 */, + 2, 3, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 178 */, + 0, 1, 2, 3, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 179 */, + 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 180 */, + 0, 1, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 181 */, + 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 182 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 183 */, + 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 184 */, + 0, 1, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 185 */, + 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 186 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 187 */, + 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1, -1, -1 /* 188 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 189 */, + 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1, -1, -1 /* 190 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 14, 15, -1, -1 /* 191 */, + 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 192 */, + 0, 1, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 193 */, + 2, 3, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 194 */, + 0, 1, 2, 3, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 195 */, + 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 196 */, + 0, 1, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, 
-1, -1 /* 197 */, + 2, 3, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 198 */, + 0, 1, 2, 3, 4, 5, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 199 */, + 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 200 */, + 0, 1, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 201 */, + 2, 3, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 202 */, + 0, 1, 2, 3, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 203 */, + 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 204 */, + 0, 1, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 205 */, + 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 206 */, + 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, -1, -1, -1, -1 /* 207 */, + 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 208 */, + 0, 1, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 209 */, + 2, 3, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 210 */, + 0, 1, 2, 3, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 211 */, + 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 212 */, + 0, 1, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 213 */, + 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 214 */, + 0, 1, 2, 3, 4, 5, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 215 */, + 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 216 */, + 0, 1, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 217 */, + 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 218 */, + 0, 1, 2, 3, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 219 */, + 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 220 */, + 0, 1, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 221 */, + 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1, -1, -1 /* 222 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, 15, -1, -1 /* 223 */, + 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 /* 224 */, + 0, 1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 225 */, + 
2, 3, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 226 */, + 0, 1, 2, 3, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 227 */, + 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 228 */, + 0, 1, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 229 */, + 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 230 */, + 0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 231 */, + 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 232 */, + 0, 1, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 233 */, + 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 234 */, + 0, 1, 2, 3, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 235 */, + 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 236 */, + 0, 1, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 237 */, + 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 238 */, + 0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13, 14, 15, -1, -1 /* 239 */, + 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1 /* 240 */, + 0, 1, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 241 */, + 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 242 */, + 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 243 */, + 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 244 */, + 0, 1, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 245 */, + 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 246 */, + 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 247 */, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1 /* 248 */, + 0, 1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 249 */, + 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 250 */, + 0, 1, 2, 3, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 251 */, + 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1 /* 252 */, + 0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 253 */, + 2, 3, 4, 5, 
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, -1, -1 /* 254 */, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 /* 255 */, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_aarch64_rej_uniform_table \ + MLKEM_NAMESPACE(empty_cu_aarch64_rej_uniform_table) +int empty_cu_aarch64_rej_uniform_table; +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. 
This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. They are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. 
+ */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions." +#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache, (e) to/from bytes. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is in normal order. 
+ * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. 
+ * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial vector operand. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial vector operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 .. 
Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: De-serialization of a polynomial. + * Counterpart of poly_tobytes_native: reads a + * polynomial back from its byte encoding. + * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Return -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len). 
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. + * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= 
MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. + * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-value predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { unsigned k; (0 <= k && k < MLKEM_N) ==> + * ((0 <= a->coeffs[k]) && (a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1. 
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit leading underscore + * in front of assembly symbols. We thus introduce a separate + * namespace wrapper for ASM symbols. 
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * than configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers. 
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202_[BACKEND]_ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM[LEVEL]_[BACKEND]_ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLKEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if 
(err == 1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on value to check (lower bound is 0) + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. 
+ * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64_AVX2 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. 
+ * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. + **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. 
*/ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + * key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * polyvec *b: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input 
serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. 
+ * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block at a time until we're done. + */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. */ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block at a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. 
*/ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. + */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. 
*/ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). 
+ **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_keypair_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + 
+#include +#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))) +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. 
+ * Described in Section 7.2 of FIPS203. + * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no secret information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * length MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - root: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define 
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < 
(1u << 11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); + extkey[MLKEM_SYMBYTES] = nonce; /* extkey = seed || nonce */ + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); /* r comes from poly_cbd_eta2, so check against ETA2 (not ETA1), as poly_getnoise_eta1122_4x does for its eta2 outputs */ +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, 
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328). 
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DU) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DU bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress_dv + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Coefficients are expected to already be in + * unsigned canonical form (see Arguments). 
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < q in absolute value. + * + * The result is coefficient-wise bound by 3/2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute(). 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be subtracted + *from. + * - const poly *b: Pointer to input polynomial that is subtracted from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing.
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + 
+MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 3/2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constrain native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check.
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * 
Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients + * normalized in [0..4095].
+ * - const uint8_t *a: pointer to input byte array (of length MLKEM_POLYVECBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value.
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible.
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + *
Referring to (eq 4.3), `IN` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system.
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORDER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ */ +#endif /* defined(__BYTE_ORDER__) */ +#endif /* !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)).
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour.
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time.
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value.
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_aarch64/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. + */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/LICENSE new file mode 100644 index 0000000000..7922ab8007 --- /dev/null 
+++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/LICENSE @@ -0,0 +1,6 @@ +Public Domain (https://creativecommons.org/share-your-work/public-domain/cc0/); +or Apache 2.0 License (https://www.apache.org/licenses/LICENSE-2.0.html). + +For Keccak and AES we are using public-domain +code from sources and by authors listed in +comments on top of the respective files. diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. 
/*
 * These functions are meant to be trivial wrappers around the chosen native
 * implementation. They are static inline to avoid unnecessary calls.
 * The macro before each declaration controls whether a native
 * implementation is present.
 */

#if defined(MLKEM_USE_NATIVE_NTT)
/*************************************************
 * Name:        ntt_native
 *
 * Description: Computes negacyclic number-theoretic transform (NTT) of
 *              a polynomial in place.
 *
 *              The input polynomial is assumed to be in normal order.
 *              The output polynomial is in bitreversed order, or of a
 *              custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set.
 *              See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER
 *              for more information.
 *
 *              Only declared when the backend defines MLKEM_USE_NATIVE_NTT.
 *
 * Arguments:   - poly *p: pointer to in/output polynomial
 **************************************************/
static INLINE void ntt_native(poly *);
#endif /* MLKEM_USE_NATIVE_NTT */
/*************************************************
 * Name:        poly_permute_bitrev_to_custom
 *
 * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined,
 *              convert a polynomial in NTT domain from bitreversed
 *              order to the custom order output by the native NTT.
 *
 *              This must only be defined if there is native code for
 *              all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache,
 *              and (e) the to/from byte stream conversions. The
 *              preprocessor check guarding this declaration rejects
 *              any profile that lacks one of them.
 *
 * Arguments:   - poly *p: pointer to in/output polynomial
 *
 **************************************************/
static INLINE void poly_permute_bitrev_to_custom(poly *);
#if defined(MLKEM_USE_NATIVE_INTT)
/*************************************************
 * Name:        intt_native
 *
 * Description: Computes inverse of negacyclic number-theoretic transform (NTT)
 *              of a polynomial in place.
 *
 *              The input polynomial is in bitreversed order, or of a
 *              custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set.
 *              See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER
 *              for more information.
 *              The output polynomial is assumed to be in normal order.
 *
 * Arguments:   - poly *p: pointer to in/output polynomial
 **************************************************/
static INLINE void intt_native(poly *);
#endif /* MLKEM_USE_NATIVE_INTT */
#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED)
/*************************************************
 * Name:        polyvec_basemul_acc_montgomery_cached_native
 *
 * Description: Compute multiplication of polynomial vectors in NTT domain,
 *              producing a single polynomial.
 *              NOTE(review): the name suggests the per-entry products are
 *              accumulated (scalar product) in Montgomery form -- confirm
 *              against the backend implementation.
 *
 * Arguments:   INPUT:
 *              - a: First operand (vector of polynomials).
 *                 This must be in NTT domain and in bitreversed order, or of
 *                 a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set.
 *                 See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER
 *                 for more information.
 *              - b: Second operand (vector of polynomials).
 *                 As for a.
 *              - b_cache: Multiplication-cache for b.
 *              OUTPUT
 *              - r: Result of the base multiplication. This is again
 *                   in NTT domain, and of the same order as a and b.
 **************************************************/
static INLINE void polyvec_basemul_acc_montgomery_cached_native(
    poly *r, const polyvec *a, const polyvec *b,
    const polyvec_mulcache *b_cache);
#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */
#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES)
/*************************************************
 * Name:        poly_frombytes_native
 *
 * Description: De-serialization of a polynomial from a byte array.
 *              NOTE(review): presumably the inverse direction of
 *              poly_tobytes_native; the coefficient range of the result
 *              is not stated here -- confirm against the backend.
 *
 * Arguments:   INPUT:
 *              - r: const pointer to input byte array
 *                   (of MLKEM_POLYBYTES bytes)
 *              OUTPUT
 *              - a: pointer to output polynomial in NTT domain
 **************************************************/
static INLINE void poly_frombytes_native(poly *a,
                                         const uint8_t r[MLKEM_POLYBYTES]);
#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
/*************************************************
 * Name:        load32_littleendian
 *
 * Description: Assemble a 32-bit unsigned integer from four bytes,
 *              where x[0] is the least significant byte and x[3]
 *              the most significant one.
 *
 * Arguments:   - const uint8_t *x: pointer to the 4 input bytes
 *
 * Returns the 32-bit unsigned integer read from x
 **************************************************/
static uint32_t load32_littleendian(const uint8_t x[4])
{
  /* Each byte is widened to 32 bits before shifting so that no bits are
   * lost and no shift happens on a promoted (signed) int. */
  const uint32_t byte0 = (uint32_t)x[0];
  const uint32_t byte1 = (uint32_t)x[1] << 8;
  const uint32_t byte2 = (uint32_t)x[2] << 16;
  const uint32_t byte3 = (uint32_t)x[3] << 24;
  return byte0 | byte1 | byte2 | byte3;
}
#if MLKEM_ETA1 == 3
/*************************************************
 * Name:        cbd3
 *
 * Description: Given an array of uniformly random bytes, compute
 *              polynomial with coefficients distributed according to
 *              a centered binomial distribution with parameter eta=3.
 *              This function is only needed for ML-KEM-512
 *
 *              Every outer iteration consumes 3 input bytes (24 bits)
 *              and writes 4 coefficients; each coefficient is the
 *              difference of two 3-bit population counts and hence
 *              lies in -3 .. 3, as stated by the loop invariants
 *              (absolute value < 4).
 *
 * Arguments:   - poly *r: pointer to output polynomial
 *              - const uint8_t *buf: pointer to input byte array
 *                (3 * MLKEM_N / 4 bytes)
 **************************************************/
static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4])
{
  unsigned i;
  for (i = 0; i < MLKEM_N / 4; i++)
  __loop__(
    invariant(i >= 0 && i <= MLKEM_N / 4)
    invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4)))
  {
    unsigned j;
    /* 24 fresh input bits for this group of 4 coefficients */
    const uint32_t t = load24_littleendian(buf + 3 * i);
    /* 0x00249249 has bits 0, 3, 6, ..., 21 set. Adding the three shifted
     * and masked copies of t leaves, in every 3-bit field of d, the number
     * of set bits (0..3) of the corresponding 3-bit group of t. */
    uint32_t d = t & 0x00249249;
    d += (t >> 1) & 0x00249249;
    d += (t >> 2) & 0x00249249;

    for (j = 0; j < 4; j++)
    __loop__(
      invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4)
      invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4)))
    {
      /* Coefficient j uses two adjacent 3-bit counters of d */
      const int16_t a = (d >> (6 * j + 0)) & 0x7;
      const int16_t b = (d >> (6 * j + 3)) & 0x7;
      r->coeffs[4 * i + j] = a - b;
    }
  }
}
#endif /* MLKEM_ETA1 == 3 */
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2. 
/*************************************************
 * Name:        poly_cbd_eta2
 *
 * Description: Given an array of uniformly random bytes, compute
 *              polynomial with coefficients distributed according to
 *              a centered binomial distribution with parameter MLKEM_ETA2.
 *
 *              The contract states that r and buf do not alias, that
 *              only *r is written, and that on return every coefficient
 *              of r has absolute value < MLKEM_ETA2 + 1.
 *
 * Arguments:   - poly *r: pointer to output polynomial
 *              - const uint8_t *buf: pointer to input byte array
 *                (MLKEM_ETA2 * MLKEM_N / 4 bytes)
 **************************************************/
MLKEM_NATIVE_INTERNAL_API
void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4])
__contract__(
  requires(memory_no_alias(r, sizeof(poly)))
  requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4))
  assigns(memory_slice(r, sizeof(poly)))
  ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1))
);
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
/*
 * Boolean-valued predicate that asserts that "all values of array_var
 * at indices qvar_lb (inclusive) .. qvar_ub (exclusive) are in
 * range value_lb (inclusive) .. value_ub (exclusive)"
 * Example:
 *  array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)
 * expands to
 *  __CPROVER_forall { unsigned k; ((0) <= (k) && (k) < (MLKEM_N)) ==>
 *     (((0) <= (a->coeffs[(k)])) && ((a->coeffs[(k)]) < (MLKEM_Q))) }
 * where k stands for a quantified variable whose name is generated
 * from _cbmc_idx and the current __LINE__.
 */

/*
 * Prevent clang-format from corrupting CBMC's special ==> operator
 */
/* clang-format off */
/* Two-level concatenation so that __LINE__ is expanded before pasting */
#define CBMC_CONCAT_(left, right) left##right
#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right)

/* Core form of array_bound taking the name of the quantified variable */
#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \
                         value_lb, value_ub)                \
  __CPROVER_forall                                          \
  {                                                         \
    unsigned qvar;                                          \
    ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==>         \
      (((value_lb) <= (array_var[(qvar)])) &&               \
       ((array_var[(qvar)]) < (value_ub)))                  \
  }

#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \
  array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb),      \
      (qvar_ub), (array_var), (value_lb), (value_ub))
/* clang-format on */
/* Wrapper around array_bound operating on absolute values:
 * asserts that arr[i] lies in -(k) + 1 .. (k) - 1 for all indices
 * lb (inclusive) .. ub (exclusive), i.e. that its absolute value is < k.
 *
 * Note that since the absolute bound is exclusive, but the lower
 * bound in array_bound is inclusive, we have to raise it by 1.
 */
#define array_abs_bound(arr, lb, ub, k) \
  array_bound((arr), (lb), (ub), -(k) + 1, (k))
/* On Apple platforms, we need to emit leading underscore
 * in front of assembly symbols. We thus introduce a separate
 * namespace wrapper for ASM symbols. */
#if !defined(__APPLE__)
#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym)
#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym)
#else
/* Two-level macro so that the namespaced symbol is fully expanded
 * before the underscore is pasted in front of it. */
#define PREFIX_UNDERSCORE_(sym) _##sym
#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym)
#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym))
#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym))
#endif
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
/* Common prefix of all debug error messages: source file and line */
#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] "

/* Debug assertion: returns silently when val is non-zero; otherwise
 * reports file, line and description on stderr and calls exit(1). */
void mlkem_debug_assert(const char *file, int line, const char *description,
                        const int val)
{
  if (val != 0)
  {
    return;
  }

  fprintf(stderr,
          MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n",
          file, line, description, val);
  exit(1);
}

/* Debug bounds check over an array of len int16_t values.
 *
 * Every element must lie strictly between lower_bound_exclusive and
 * upper_bound_exclusive. All offending elements are reported on stderr
 * before the process is terminated with exit(1), so a single run lists
 * every violation rather than only the first one. */
void mlkem_debug_check_bounds(const char *file, int line,
                              const char *description, const int16_t *ptr,
                              unsigned len, int lower_bound_exclusive,
                              int upper_bound_exclusive)
{
  int violation_seen = 0;
  unsigned idx;

  for (idx = 0; idx < len; idx++)
  {
    const int16_t coeff = ptr[idx];
    const int in_range =
        (coeff > lower_bound_exclusive) && (coeff < upper_bound_exclusive);

    if (in_range)
    {
      continue;
    }

    fprintf(stderr,
            MLKEM_NATIVE_DEBUG_ERROR_HEADER
            "%s, index %u, value %d out of bounds (%d,%d)\n",
            file, line, description, idx, (int)coeff, lower_bound_exclusive,
            upper_bound_exclusive);
    violation_seen = 1;
  }

  if (violation_seen)
  {
    exit(1);
  }
}
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. + * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64_AVX2 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. 
with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. + * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients will be normalized to [0,..,q-1]. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. 
+ **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. */ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + *key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * polyvec *b: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void 
pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* seed is MLKEM_SYMBYTES + 2 bytes long, but padded to MLKEM_SYMBYTES + 16 */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block at a time until we're done. 
+ */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. 
*/ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block at a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. */ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. 
+ */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For left over polynomial, we use single keccak. */ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. 
Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). + **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + + +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_enc_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output 
buffers are needed. + * The last parameter is a dummy that's overwritten later. + */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertion at the top */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion at the top */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + +#include 
+#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))); +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + *seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying Kyber. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. + * Described in Section 7.2 of FIPS203. 
+ * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no public information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * length MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, ntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 
320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < (1u << 11)))) + { + t[k] = 
scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, 
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328). 
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. 
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and a nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + 
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise non-negative and < UINT12_LIMIT. + * + * The result is coefficient-wise bound by 2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute(). 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be subtracted +*from. + * - const poly *b: Pointer to input polynomial to subtract from r + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + +MLKEM_NATIVE_INTERNAL_API 
+void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constrain native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * Arguments: - 
uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients + * normalized in [0..4095]. 
+ * - const uint8_t *a: pointer to input byte array (MLKEM_POLYVECBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. 
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^16) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible. 
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than len, all of the input buffer + * is guaranteed to have been consumed.
If it is equal to len, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * Referring to 
(eq 4.3), `OUT` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* __BYTE_ORDER__ */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ */ +#endif /* defined(__BYTE_ORDER__) */ +#endif /* !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)).
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_xxx functions below make deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour.
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time.
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value.
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_ref/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. + */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/LICENSE b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/LICENSE new file mode 100644 index 0000000000..7922ab8007 --- /dev/null +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/LICENSE @@ -0,0 +1,6 @@ +Public Domain (https://creativecommons.org/share-your-work/public-domain/cc0/); +or Apache 2.0 License (https://www.apache.org/licenses/LICENSE-2.0.html). + +For Keccak and AES we are using public-domain +code from sources and by authors listed in +comments on top of the respective files. diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/api.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/api.h new file mode 100644 index 0000000000..792ecb8a4a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/api.h @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Native arithmetic interface + * + * This header is primarily for documentation purposes. + * It should not be included by backend implementations. + * + * To ensure consistency with backends, the header will be + * included automatically after inclusion of the active + * backend, to ensure consistency of function signatures, + * and run sanity checks. + */ +#ifdef MLKEM_NATIVE_ARITH_NATIVE_API_H +#error \ + "The arithmetic backend API `mlkem/native/api.h` " \ + "should not be directly included. Please include the relevant " \ + "structure headers directly." +#else /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ +#define MLKEM_NATIVE_ARITH_NATIVE_API_H + +#include +#include "poly.h" +#include "polyvec.h" + +/* + * This is the C<->native interface allowing for the drop-in of + * native code for performance critical arithmetic components of ML-KEM. + * + * A _backend_ is a specific implementation of (part of) this interface. + * + * To add a function to a backend, define MLKEM_USE_NATIVE_XXX and + * implement `static inline xxx(...)` in the profile header. + * + * The only exception is MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER. 
This option can + * be set if there are native implementations for all of NTT, invNTT, and + * base multiplication, and allows the native implementation to use a + * custom order of polynomial coefficients in NTT domain -- the use of such + * custom order is not an implementation-detail since the public matrix + * is generated in NTT domain. In this case, a permutation function + * poly_permute_bitrev_to_custom() needs to be provided that permutes + * polynomials in NTT domain from bitreversed to the custom order. + */ + +/* + * Those functions are meant to be trivial wrappers around the chosen native + * implementation. The are static inline to avoid unnecessary calls. + * The macro before each declaration controls whether a native + * implementation is present. + */ + +#if defined(MLKEM_USE_NATIVE_NTT) +/************************************************* + * Name: ntt_native + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input polynomial is assumed to be in normal order. + * The output polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void ntt_native(poly *); +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* + * This must only be set if NTT, invNTT, basemul, mulcache, and + * to/from byte stream conversions all have native implementations + * that are adapted to the custom order. 
+ */ +#if !defined(MLKEM_USE_NATIVE_NTT) || !defined(MLKEM_USE_NATIVE_INTT) || \ + !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) || \ + !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) || \ + !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) || \ + !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +#error \ + "Invalid native profile: MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER can only be \ +set if there are native implementations for NTT, invNTT, mulcache, basemul, \ +and to/from bytes conversions." +#endif + +/************************************************* + * Name: poly_permute_bitrev_to_custom + * + * Description: When MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is defined, + * convert a polynomial in NTT domain from bitreversed + * order to the custom order output by the native NTT. + * + * This must only be defined if there is native code for + * all of (a) NTT, (b) invNTT, (c) basemul, (d) mulcache. + * Arguments: - poly *p: pointer to in/output polynomial + * + **************************************************/ +static INLINE void poly_permute_bitrev_to_custom(poly *); +#endif /* MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +#if defined(MLKEM_USE_NATIVE_INTT) +/************************************************* + * Name: intt_native + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place. + * + * The input polynomial is in bitreversed order, or of a + * custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * The output polynomial is assumed to be in normal order. 
+ * + * Arguments: - poly *p: pointer to in/output polynomial + **************************************************/ +static INLINE void intt_native(poly *); +#endif /* MLKEM_USE_NATIVE_INTT */ + +#if defined(MLKEM_USE_NATIVE_POLY_REDUCE) +/************************************************* + * Name: poly_reduce_native + * + * Description: Applies modular reduction to all coefficients of a polynomial. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_reduce_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +#if defined(MLKEM_USE_NATIVE_POLY_TOMONT) +/************************************************* + * Name: poly_tomont_native + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +static INLINE void poly_tomont_native(poly *); +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +/************************************************* + * Name: poly_mulcache_compute_native + * + * Description: Compute multiplication cache for a polynomial + * in NTT domain. + * + * The purpose of the multiplication cache is to + * cache repeated computations required during a + * base multiplication of polynomials in NTT domain. + * The structure of the multiplication-cache is + * implementation defined. + * + * Arguments: INPUT: + * - poly: const pointer to input polynomial. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information.
+ * OUTPUT + * - cache: pointer to multiplication cache + **************************************************/ +static INLINE void poly_mulcache_compute_native(poly_mulcache *cache, + const poly *poly); +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached_native + * + * Description: Compute multiplication of polynomials in NTT domain. + * + * Arguments: INPUT: + * - a: First polynomial operand. + * This must be in NTT domain and in bitreversed order, or of + * a custom order if MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER is set. + * See the documentation of MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + * for more information. + * - b: Second polynomial operand. + * As for a. + * - b_cache: Multiplication-cache for b. + * OUTPUT + * - r: Result of the base multiplication. This is again + * in NTT domain, and of the same order as a and b. + **************************************************/ +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); +#endif + +#if defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +/************************************************* + * Name: poly_tobytes_native + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. + * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range -Q+1 ..
Q-1 + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a); +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +/************************************************* + * Name: poly_frombytes_native + * + * Description: De-serialization of a polynomial. + * The polynomial is read from a byte array + * of MLKEM_POLYBYTES bytes. + * + * Arguments: INPUT: + * - r: const pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - a: pointer to output polynomial in NTT domain + **************************************************/ +static INLINE void poly_frombytes_native(poly *a, + const uint8_t r[MLKEM_POLYBYTES]); +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +#if defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +/************************************************* + * Name: rej_uniform_native + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int len: requested number of 16-bit integers + * (uniform mod q). + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes. + * + * Return -1 if the native implementation does not support the input lengths. + * Otherwise, returns non-negative number of sampled 16-bit integers (at most + * len).
+ **************************************************/ +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen); +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +#endif /* MLKEM_NATIVE_ARITH_NATIVE_API_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/arith_backend.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/arith_backend.h new file mode 100644 index 0000000000..09e30f207a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/arith_backend.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#if !defined(MLKEM_NATIVE_ARITH_IMPL_H) +#define MLKEM_NATIVE_ARITH_IMPL_H + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_IMPL) +#include MLKEM_NATIVE_ARITH_BACKEND_IMPL + +/* Include to enforce consistency of API and implementation, + * and conduct sanity checks on the backend. + * + * Keep this _after_ the inclusion of the backend; otherwise, + * the sanity checks won't have an effect. */ +#include "api.h" +#endif + +#endif /* MLKEM_NATIVE_ARITH_IMPL_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.c new file mode 100644 index 0000000000..433bdc954b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "cbd.h" +#include + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define load32_littleendian MLKEM_NAMESPACE(load32_littleendian) +#define load24_littleendian MLKEM_NAMESPACE(load24_littleendian) +#define cbd2 MLKEM_NAMESPACE(cbd2) +#define cbd3 MLKEM_NAMESPACE(cbd3) +/* End of static namespacing */ + +/************************************************* + * Name: load32_littleendian + * + * Description: load 4 bytes into a 32-bit integer + * in little-endian order + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x + **************************************************/ +static uint32_t load32_littleendian(const uint8_t x[4]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + r |= (uint32_t)x[3] << 24; + return r; +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: load24_littleendian + * + * Description: load 3 bytes into a 32-bit integer + * in little-endian order. + * This function is only needed for ML-KEM-512 + * + * Arguments: - const uint8_t *x: pointer to input byte array + * + * Returns 32-bit unsigned integer loaded from x (most significant byte is zero) + **************************************************/ +static uint32_t load24_littleendian(const uint8_t x[3]) +{ + uint32_t r; + r = (uint32_t)x[0]; + r |= (uint32_t)x[1] << 8; + r |= (uint32_t)x[2] << 16; + return r; +} +#endif /* MLKEM_ETA1 == 3 */ + +/************************************************* + * Name: cbd2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd2(poly *r, const uint8_t buf[2 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= 
MLKEM_N / 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i, 3))) + { + unsigned j; + uint32_t t = load32_littleendian(buf + 4 * i); + uint32_t d = t & 0x55555555; + d += (t >> 1) & 0x55555555; + + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_abs_bound(r->coeffs, 0, 8 * i + j, 3))) + { + const int16_t a = (d >> (4 * j + 0)) & 0x3; + const int16_t b = (d >> (4 * j + 2)) & 0x3; + r->coeffs[8 * i + j] = a - b; + } + } +} + +#if MLKEM_ETA1 == 3 +/************************************************* + * Name: cbd3 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter eta=3. + * This function is only needed for ML-KEM-512 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +static void cbd3(poly *r, const uint8_t buf[3 * MLKEM_N / 4]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 4))) + { + unsigned j; + const uint32_t t = load24_littleendian(buf + 3 * i); + uint32_t d = t & 0x00249249; + d += (t >> 1) & 0x00249249; + d += (t >> 2) & 0x00249249; + + for (j = 0; j < 4; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 4 && j >= 0 && j <= 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i + j, 4))) + { + const int16_t a = (d >> (6 * j + 0)) & 0x7; + const int16_t b = (d >> (6 * j + 3)) & 0x7; + r->coeffs[4 * i + j] = a - b; + } + } +} +#endif /* MLKEM_ETA1 == 3 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +{ +#if MLKEM_ETA1 == 2 + cbd2(r, buf); +#elif MLKEM_ETA1 == 3 + cbd3(r, buf); +#else +#error "This implementation requires eta1 in {2,3}" +#endif +} + +#if MLKEM_K == 2 || MLKEM_K == 4 
+MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +{ +#if MLKEM_ETA2 == 2 + cbd2(r, buf); +#else +#error "This implementation requires eta2 = 2" +#endif +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.h new file mode 100644 index 0000000000..15db895708 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbd.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef CBD_H +#define CBD_H + +#include +#include "common.h" +#include "poly.h" + +#define poly_cbd_eta1 MLKEM_NAMESPACE(poly_cbd_eta1) +/************************************************* + * Name: poly_cbd_eta1 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA1. + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta1(poly *r, const uint8_t buf[MLKEM_ETA1 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA1 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA1 + 1)) +); + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_cbd_eta2 MLKEM_NAMESPACE(poly_cbd_eta2) +/************************************************* + * Name: poly_cbd_eta2 + * + * Description: Given an array of uniformly random bytes, compute + * polynomial with coefficients distributed according to + * a centered binomial distribution with parameter MLKEM_ETA2.
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *buf: pointer to input byte array + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_cbd_eta2(poly *r, const uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(buf, MLKEM_ETA2 * MLKEM_N / 4)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbmc.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbmc.h new file mode 100644 index 0000000000..baa0bfa9fb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/cbmc.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/*************************************************** + * Basic replacements for __CPROVER_XXX contracts + ***************************************************/ + +#include "common.h" + +#ifndef CBMC + +#define __contract__(x) +#define __loop__(x) +#define cassert(x, y) + +#else /* CBMC _is_ defined, therefore we're doing proof */ + +#define __contract__(x) x +#define __loop__(x) x + +/* https://diffblue.github.io/cbmc/contracts-assigns.html */ +#define assigns(...) __CPROVER_assigns(__VA_ARGS__) + +/* https://diffblue.github.io/cbmc/contracts-requires-ensures.html */ +#define requires(...) __CPROVER_requires(__VA_ARGS__) +#define ensures(...) __CPROVER_ensures(__VA_ARGS__) +/* https://diffblue.github.io/cbmc/contracts-loops.html */ +#define invariant(...) __CPROVER_loop_invariant(__VA_ARGS__) +#define decreases(...) __CPROVER_decreases(__VA_ARGS__) +/* cassert to avoid confusion with in-built assert */ +#define cassert(...) __CPROVER_assert(__VA_ARGS__) +#define assume(...) 
__CPROVER_assume(__VA_ARGS__) + +/*************************************************** + * Macros for "expression" forms that may appear + * _inside_ top-level contracts. + ***************************************************/ + +/* + * function return value - useful inside ensures + * https://diffblue.github.io/cbmc/contracts-functions.html + */ +#define return_value (__CPROVER_return_value) + +/* + * assigns l-value targets + * https://diffblue.github.io/cbmc/contracts-assigns.html + */ +#define object_whole(...) __CPROVER_object_whole(__VA_ARGS__) +#define memory_slice(...) __CPROVER_object_upto(__VA_ARGS__) +#define same_object(...) __CPROVER_same_object(__VA_ARGS__) + +/* + * Pointer-related predicates + * https://diffblue.github.io/cbmc/contracts-memory-predicates.html + */ +#define memory_no_alias(...) __CPROVER_is_fresh(__VA_ARGS__) +#define readable(...) __CPROVER_r_ok(__VA_ARGS__) +#define writeable(...) __CPROVER_w_ok(__VA_ARGS__) + +/* + * History variables + * https://diffblue.github.io/cbmc/contracts-history-variables.html + */ +#define old(...) __CPROVER_old(__VA_ARGS__) +#define loop_entry(...) __CPROVER_loop_entry(__VA_ARGS__) + +/* + * Quantifiers + * Note that the range on qvar is _exclusive_ between qvar_lb .. 
qvar_ub + * https://diffblue.github.io/cbmc/contracts-quantifiers.html + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define forall(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> (predicate) \ + } + +#define EXISTS(qvar, qvar_lb, qvar_ub, predicate) \ + __CPROVER_exists \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) && (predicate) \ + } +/* clang-format on */ + +/*************************************************** + * Convenience macros for common contract patterns + ***************************************************/ + +/* + * Boolean-value predicate that asserts that "all values of array_var are in + * range value_lb (inclusive) .. value_ub (exclusive)" + * Example: + * array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q) + * expands to + * __CPROVER_forall { int k; (0 <= k && k <= MLKEM_N-1) ==> ( + * 0 <= a->coeffs[k]) && a->coeffs[k] < MLKEM_Q)) } + */ + +/* + * Prevent clang-format from corrupting CBMC's special ==> operator + */ +/* clang-format off */ +#define CBMC_CONCAT_(left, right) left##right +#define CBMC_CONCAT(left, right) CBMC_CONCAT_(left, right) + +#define array_bound_core(qvar, qvar_lb, qvar_ub, array_var, \ + value_lb, value_ub) \ + __CPROVER_forall \ + { \ + unsigned qvar; \ + ((qvar_lb) <= (qvar) && (qvar) < (qvar_ub)) ==> \ + (((value_lb) <= (array_var[(qvar)])) && \ + ((array_var[(qvar)]) < (value_ub))) \ + } + +#define array_bound(array_var, qvar_lb, qvar_ub, value_lb, value_ub) \ + array_bound_core(CBMC_CONCAT(_cbmc_idx, __LINE__), (qvar_lb), \ + (qvar_ub), (array_var), (value_lb), (value_ub)) +/* clang-format on */ + +/* Wrapper around array_bound operating on absolute values. + * + * Note that since the absolute bound is exclusive, but the lower + * bound in array_bound is inclusive, we have to raise it by 1.
+ */ +#define array_abs_bound(arr, lb, ub, k) \ + array_bound((arr), (lb), (ub), -(k) + 1, (k)) + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/common.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/common.h new file mode 100644 index 0000000000..da886780c3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/common.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_COMMON_H +#define MLKEM_NATIVE_COMMON_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#include "params.h" +#include "sys.h" + +/* Include backend metadata */ +#if defined(MLKEM_USE_NATIVE) +#if defined(MLKEM_NATIVE_ARITH_BACKEND) +#include MLKEM_NATIVE_ARITH_BACKEND +#endif +#if defined(MLKEM_NATIVE_FIPS202_BACKEND) +#include MLKEM_NATIVE_FIPS202_BACKEND +#endif +#endif + +#if !defined(MLKEM_NATIVE_ARITH_BACKEND_NAME) +#define MLKEM_NATIVE_ARITH_BACKEND_NAME C +#endif + +#if !defined(MLKEM_NATIVE_FIPS202_BACKEND_NAME) +#define MLKEM_NATIVE_FIPS202_BACKEND_NAME C +#endif + +/* For a monobuild (where all compilation units are merged into one), mark + * all non-public API as static since they don't need external linkage. */ +#if !defined(MLKEM_NATIVE_MONOBUILD) +#define MLKEM_NATIVE_INTERNAL_API +#else +#define MLKEM_NATIVE_INTERNAL_API static +#endif + +#define MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) x1##_##x2 +#define MLKEM_NATIVE_MAKE_NAMESPACE(x1, x2) MLKEM_NATIVE_MAKE_NAMESPACE_(x1, x2) + +#define FIPS202_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(FIPS202_NAMESPACE_PREFIX, s) + +#define MLKEM_NAMESPACE(s) \ + MLKEM_NATIVE_MAKE_NAMESPACE(MLKEM_NAMESPACE_PREFIX, s) + +/* On Apple platforms, we need to emit leading underscore + * in front of assembly symbols. We thus introducee a separate + * namespace wrapper for ASM symbols. 
*/ +#if !defined(__APPLE__) +#define MLKEM_ASM_NAMESPACE(sym) MLKEM_NAMESPACE(sym) +#define FIPS202_ASM_NAMESPACE(sym) FIPS202_NAMESPACE(sym) +#else +#define PREFIX_UNDERSCORE_(sym) _##sym +#define PREFIX_UNDERSCORE(sym) PREFIX_UNDERSCORE_(sym) +#define MLKEM_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(MLKEM_NAMESPACE(sym)) +#define FIPS202_ASM_NAMESPACE(sym) PREFIX_UNDERSCORE(FIPS202_NAMESPACE(sym)) +#endif + +#endif /* MLKEM_NATIVE_COMMON_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/config.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/config.h new file mode 100644 index 0000000000..d1441835b0 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/config.h @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef MLKEM_NATIVE_CONFIG_H +#define MLKEM_NATIVE_CONFIG_H + +/****************************************************************************** + * Name: MLKEM_K + * + * Description: Determines the security level for ML-KEM + * - MLKEM_K=2 corresponds to ML-KEM-512 + * - MLKEM_K=3 corresponds to ML-KEM-768 + * - MLKEM_K=4 corresponds to ML-KEM-1024 + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#ifndef MLKEM_K +#define MLKEM_K 3 /* Change this for different security strengths */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_CONFIG_FILE + * + * Description: If defined, this is a header that will be included instead + * of this default configuration file mlkem/config.h. + * + * When you need to build mlkem-native in multiple configurations, + * using varying MLKEM_NATIVE_CONFIG_FILE can be more convenient + * then configuring everything through CFLAGS. + * + * To use, MLKEM_NATIVE_CONFIG_FILE _must_ be defined prior + * to the inclusion of any mlkem-native headers. 
For example, + * it can be set by passing `-DMLKEM_NATIVE_CONFIG_FILE="..."` + * on the command line. + * + *****************************************************************************/ +/* #define MLKEM_NATIVE_CONFIG_FILE "config.h" */ + +/****************************************************************************** + * Name: MLKEM_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_NAMESPACE_PREFIX) +#define MLKEM_NAMESPACE_PREFIX MLKEM_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: FIPS202_NAMESPACE + * + * Description: The prefix to use to namespace global symbols + * from mlkem/fips202/. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(FIPS202_NAMESPACE_PREFIX) +#define FIPS202_NAMESPACE_PREFIX FIPS202_DEFAULT_NAMESPACE_PREFIX +#endif + +/****************************************************************************** + * Name: MLKEM_USE_NATIVE + * + * Description: Determines whether a native backend should + * be used, if available. + * + * This can also be set using CFLAGS. + * + *****************************************************************************/ +#if !defined(MLKEM_USE_NATIVE) +/* #define MLKEM_USE_NATIVE */ +#endif + +/****************************************************************************** + * Name: MLKEM_NATIVE_ARITH_BACKEND + * + * Description: The arithmetic backend to use. + * + * This must be the filename of an arithmetic backend. + * See the existing backends for examples. + * + * This can be set using CFLAGS. 
+ * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE) && !defined(MLKEM_NATIVE_ARITH_BACKEND) +#define MLKEM_NATIVE_ARITH_BACKEND "default.h" +#endif /* MLKEM_NATIVE_ARITH_BACKEND */ + +/****************************************************************************** + * Name: MLKEM_NATIVE_FIPS202_BACKEND + * + * Description: The FIPS-202 backend to use. + * + * This must be the filename of an FIPS-202 backend. + * + * This can be set using CFLAGS. + * + *****************************************************************************/ +#if defined(MLKEM_USE_NATIVE_FIPS202) && !defined(MLKEM_NATIVE_FIPS202_BACKEND) +#define MLKEM_NATIVE_FIPS202_BACKEND "native/default.h" +#endif /* MLKEM_NATIVE_FIPS202_BACKEND */ + +/************************* Config internals ********************************/ + +/* Default namespace + * + * Don't change this. If you need a different namespace, re-define + * MLKEM_NAMESPACE above instead, and remove the following. 
+ */ + +/* + * The default FIPS202 namespace is + * + * PQCP_MLKEM_NATIVE_FIPS202__ + * + * e.g., PQCP_MLKEM_NATIVE_FIPS202_C_ + */ + +#define FIPS202_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_FIPS202 + +/* + * The default MLKEM namespace is + * + * PQCP_MLKEM_NATIVE_MLKEM__ + * + * e.g., PQCP_MLKEM_NATIVE_MLKEM512_AARCH64_OPT_ + */ + +#if MLKEM_K == 2 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM512 +#elif MLKEM_K == 3 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM768 +#elif MLKEM_K == 4 +#define MLKEM_DEFAULT_NAMESPACE_PREFIX PQCP_MLKEM_NATIVE_MLKEM1024 +#endif + +#endif /* MLkEM_NATIVE_CONFIG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.c new file mode 100644 index 0000000000..64294ebe13 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.c @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "../common.h" + +#if defined(MLKEM_DEBUG) + +#include +#include "debug.h" + +#define MLKEM_NATIVE_DEBUG_ERROR_HEADER "[ERROR:%s:%04d] " + +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) +{ + if (val == 0) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER "Assertion failed: %s (value %d)\n", + file, line, description, val); + exit(1); + } +} + +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive) +{ + int err = 0; + unsigned i; + for (i = 0; i < len; i++) + { + int16_t val = ptr[i]; + if (!(val > lower_bound_exclusive && val < upper_bound_exclusive)) + { + fprintf(stderr, + MLKEM_NATIVE_DEBUG_ERROR_HEADER + "%s, index %u, value %d out of bounds (%d,%d)\n", + file, line, description, i, (int)val, lower_bound_exclusive, + upper_bound_exclusive); + err = 1; + } + } + + if 
(err == 1) + exit(1); +} + +#else /* MLKEM_DEBUG */ + +#define empty_cu_debug MLKEM_NAMESPACE(empty_cu_debug) +int empty_cu_debug; + +#endif /* MLKEM_DEBUG */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.h new file mode 100644 index 0000000000..5ce320ea2e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/debug/debug.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_DEBUG_H +#define MLKEM_DEBUG_H + +#include "../common.h" + +#if defined(MLKEM_DEBUG) +#include +#include +#include + +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +#define mlkem_debug_assert MLKEM_NAMESPACE(mlkem_debug_assert) +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + +/************************************************* + * Name: mlkem_debug_check_bounds + * + * Description: Check whether values in an array of int16_t + * are within specified bounds. + * + * Prints an error message to stderr and calls + * exit(1) if not. 
+ * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of check + * - ptr: Base of array to be checked + * - len: Number of int16_t in ptr + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound + **************************************************/ +#define mlkem_debug_check_bounds MLKEM_NAMESPACE(mlkem_debug_check_bounds) +void mlkem_debug_check_bounds(const char *file, int line, + const char *description, const int16_t *ptr, + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); + +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do \ + { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + +/* Check that all coefficients in array of int16_t's are non-negative + * and below an exclusive upper bound. 
+ * + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * high_bound: Exclusive upper bound on value to check + * msg: Message to print on failure */ +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -1, ((high_bound))); \ + } while (0) + +/* Check absolute bounds in array of int16_t's + * ptr: Base of array, expression of type int16_t* + * len: Number of int16_t in array + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \ + (len), -(abs_bound), (abs_bound)); \ + } while (0) + +/* Check absolute bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define POLY_BOUND_MSG(ptr, abs_bound, msg) \ + BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \ + msg) + +/* Check unsigned bounds on coefficients in polynomial or mulcache + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ * msg: Message to print on failure */ +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ + UBOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (ubound), \ + msg) + +/* Check absolute bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLY_BOUND(ptr, abs_bound) \ + POLY_BOUND_MSG((ptr), (abs_bound), "poly absolute bound for " #ptr) + +/* Check unsigned bounds on coefficients in polynomial + * ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. + */ +#define POLY_UBOUND(ptr, ubound) \ + POLY_UBOUND_MSG((ptr), (ubound), "poly unsigned bound for " #ptr) + +/* Check absolute bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * abs_bound: Exclusive upper bound on absolute value to check */ +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \ + "polyvec absolute bound for " #ptr ".vec[i]"); \ + } while (0) + +/* Check unsigned bounds on coefficients in vector of polynomials + * ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check + * ubound: Exclusive upper bound on value to check. Inclusive lower bound is 0. 
+ */ +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + unsigned _debug_polyvec_bound_idx; \ + for (_debug_polyvec_bound_idx = 0; _debug_polyvec_bound_idx < MLKEM_K; \ + _debug_polyvec_bound_idx++) \ + POLY_UBOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (ubound), \ + "polyvec unsigned bound for " #ptr ".vec[i]"); \ + } while (0) + +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + +/* Following AWS-LC to define a C99-compliant static assert */ +#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ + typedef struct \ + { \ + unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ + } MLKEM_CONCAT(MLKEM_NAMESPACE(static_assertion_), msg) \ + __attribute__((unused)); + +#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \ + MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix)) +#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix) +#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \ + MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix)) +#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error) + +#else /* MLKEM_DEBUG */ + +#define CASSERT(val, msg) \ + do \ + { \ + } while (0) +#define SCALAR_BOUND(val, abs_bound, msg) \ + do \ + { \ + } while (0) +#define BOUND(ptr, len, abs_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLYVEC_BOUND(ptr, abs_bound) \ + do \ + { \ + } while (0) +#define POLY_BOUND_MSG(ptr, ubound, abs_bound) \ + do \ + { \ + } while (0) +#define UBOUND(ptr, len, high_bound, msg) \ + do \ + { \ + } while (0) +#define POLY_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLYVEC_UBOUND(ptr, ubound) \ + do \ + { \ + } while (0) +#define POLY_UBOUND_MSG(ptr, ubound, msg) \ 
+ do \ + { \ + } while (0) +#define STATIC_ASSERT(cond, error) + +#endif /* MLKEM_DEBUG */ + +#endif /* MLKEM_DEBUG_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/default.h new file mode 100644 index 0000000000..d1e41c52e5 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/default.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H +#define MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H + +/* + * Default arithmetic backend + */ +#include "sys.h" + +#ifdef SYS_AARCH64 +/* + * For AArch64, we currently have one clean and one opt profile. + * We default to the opt profile. + * + * In the future, this may branch further depending on the microarchitecture. + */ +#include "aarch64/opt.h" +#endif /* SYS_AARCH64 */ + +#ifdef SYS_X86_64_AVX2 +/* + * For now, there's only one x86_64 profile, based on + * the AVX2 code from the Kyber repository. 
+ * https://github.com/pq-crystals/kyber + */ +#include "x86_64/default.h" +#endif /* SYS_X86_64_AVX2 */ + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_DEFAULT_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.c new file mode 100644 index 0000000000..4d3133e14d --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.c @@ -0,0 +1,559 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "indcpa.h" +#include +#include +#include +#include "fips202.h" +#include "fips202x4.h" +#include "indcpa.h" +#include "ntt.h" +#include "poly.h" +#include "polyvec.h" +#include "randombytes.h" +#include "rej_uniform.h" +#include "symmetric.h" + +#include "arith_backend.h" +#include "debug/debug.h" + +#include "cbmc.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define pack_pk MLKEM_NAMESPACE(pack_pk) +#define unpack_pk MLKEM_NAMESPACE(unpack_pk) +#define pack_sk MLKEM_NAMESPACE(pack_sk) +#define unpack_sk MLKEM_NAMESPACE(unpack_sk) +#define pack_ciphertext MLKEM_NAMESPACE(pack_ciphertext) +#define unpack_ciphertext MLKEM_NAMESPACE(unpack_ciphertext) +#define gen_matrix_entry_x4 MLKEM_NAMESPACE(gen_matrix_entry_x4) +#define gen_matrix_entry MLKEM_NAMESPACE(gen_matrix_entry) +#define matvec_mul MLKEM_NAMESPACE(matvec_mul) +/* End of static namespacing */ + +/************************************************* + * Name: pack_pk + * + * Description: Serialize the public key as concatenation of the + * serialized vector of polynomials pk + * and the public seed used to generate the matrix A. + * + * Arguments: uint8_t *r: pointer to the output serialized public key + * polyvec *pk: pointer to the input public-key polyvec. + * Must have coefficients within [0,..,q-1]. 
+ * const uint8_t *seed: pointer to the input public seed + **************************************************/ +static void pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], polyvec *pk, + const uint8_t seed[MLKEM_SYMBYTES]) +{ + POLYVEC_BOUND(pk, MLKEM_Q); + polyvec_tobytes(r, pk); + memcpy(r + MLKEM_POLYVECBYTES, seed, MLKEM_SYMBYTES); +} + +/************************************************* + * Name: unpack_pk + * + * Description: De-serialize public key from a byte array; + * approximate inverse of pack_pk + * + * Arguments: - polyvec *pk: pointer to output public-key polynomial vector + * Coefficients are only guaranteed to be bounded by 4096. + * - uint8_t *seed: pointer to output seed to generate matrix A + * - const uint8_t *packedpk: pointer to input serialized public + * key. + **************************************************/ +static void unpack_pk(polyvec *pk, uint8_t seed[MLKEM_SYMBYTES], + const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES]) +{ + polyvec_frombytes(pk, packedpk); + memcpy(seed, packedpk + MLKEM_POLYVECBYTES, MLKEM_SYMBYTES); + + /* NOTE: If a modulus check was conducted on the PK, we know at this + * point that the coefficients of `pk` are unsigned canonical. The + * specifications and proofs, however, do _not_ assume this, and instead + * work with the easily provable bound by 4096. 
*/ +} + +/************************************************* + * Name: pack_sk + * + * Description: Serialize the secret key + * + * Arguments: - uint8_t *r: pointer to output serialized secret key + * - polyvec *sk: pointer to input vector of polynomials (secret + * key) + **************************************************/ +static void pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], polyvec *sk) +{ + POLYVEC_BOUND(sk, MLKEM_Q); + polyvec_tobytes(r, sk); +} + +/************************************************* + * Name: unpack_sk + * + * Description: De-serialize the secret key; inverse of pack_sk + * + * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret + * key) + * - const uint8_t *packedsk: pointer to input serialized secret + * key + **************************************************/ +static void unpack_sk(polyvec *sk, + const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec_frombytes(sk, packedsk); +} + +/************************************************* + * Name: pack_ciphertext + * + * Description: Serialize the ciphertext as concatenation of the + * compressed and serialized vector of polynomials b + * and the compressed and serialized polynomial v + * + * Arguments: uint8_t *r: pointer to the output serialized ciphertext + * polyvec *b: pointer to the input vector of polynomials b + * poly *v: pointer to the input polynomial v + **************************************************/ +static void pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], polyvec *b, poly *v) +{ + polyvec_compress_du(r, b); + poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v); +} + +/************************************************* + * Name: unpack_ciphertext + * + * Description: De-serialize and decompress ciphertext from a byte array; + * approximate inverse of pack_ciphertext + * + * Arguments: - polyvec *b: pointer to the output vector of polynomials b + * - poly *v: pointer to the output polynomial v + * - const uint8_t *c: pointer to the input 
serialized ciphertext + **************************************************/ +static void unpack_ciphertext(polyvec *b, poly *v, + const uint8_t c[MLKEM_INDCPA_BYTES]) +{ + polyvec_decompress_du(b, c); + poly_decompress_dv(v, c + MLKEM_POLYVECCOMPRESSEDBYTES_DU); +} + +#ifndef MLKEM_GEN_MATRIX_NBLOCKS +#define MLKEM_GEN_MATRIX_NBLOCKS \ + ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_RATE) / XOF_RATE) +#endif + +/* + * Generate four A matrix entries from four seeds, using rejection + * sampling on the output of a XOF. + */ +static void gen_matrix_entry_x4(poly *vec, uint8_t *seed[4]) +__contract__( + requires(memory_no_alias(vec, sizeof(poly) * 4)) + requires(memory_no_alias(seed, sizeof(uint8_t*) * 4)) + requires(memory_no_alias(seed[0], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[1], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[2], MLKEM_SYMBYTES + 2)) + requires(memory_no_alias(seed[3], MLKEM_SYMBYTES + 2)) + assigns(memory_slice(vec, sizeof(poly) * 4)) + ensures(array_bound(vec[0].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[1].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[2].coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + ensures(array_bound(vec[3].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + /* Temporary buffers for XOF output before rejection sampling */ + uint8_t buf0[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf1[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf2[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + uint8_t buf3[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + + /* Tracks the number of coefficients we have already sampled */ + unsigned int ctr[KECCAK_WAY]; + xof_x4_ctx statex; + unsigned int buflen; + + shake128x4_inc_init(&statex); + + /* Each seed is MLKEM_SYMBYTES + 2 bytes long (see contract above) */ + xof_x4_absorb(&statex, seed[0], seed[1], seed[2], seed[3], + MLKEM_SYMBYTES + 2); + + /* + * Initially, squeeze heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. 
+ * This should generate the matrix entries with high probability. + */ + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, MLKEM_GEN_MATRIX_NBLOCKS, + &statex); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, 0, buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, 0, buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, 0, buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, 0, buf3, buflen); + + /* + * So long as not all matrix entries have been generated, squeeze + * one more block at a time until we're done. + */ + buflen = XOF_RATE; + while (ctr[0] < MLKEM_N || ctr[1] < MLKEM_N || ctr[2] < MLKEM_N || + ctr[3] < MLKEM_N) + __loop__( + assigns(ctr, statex, memory_slice(vec, sizeof(poly) * 4), object_whole(buf0), + object_whole(buf1), object_whole(buf2), object_whole(buf3)) + invariant(ctr[0] <= MLKEM_N && ctr[1] <= MLKEM_N) + invariant(ctr[2] <= MLKEM_N && ctr[3] <= MLKEM_N) + invariant(ctr[0] > 0 ==> array_bound(vec[0].coeffs, 0, ctr[0], 0, MLKEM_Q)) + invariant(ctr[1] > 0 ==> array_bound(vec[1].coeffs, 0, ctr[1], 0, MLKEM_Q)) + invariant(ctr[2] > 0 ==> array_bound(vec[2].coeffs, 0, ctr[2], 0, MLKEM_Q)) + invariant(ctr[3] > 0 ==> array_bound(vec[3].coeffs, 0, ctr[3], 0, MLKEM_Q))) + { + xof_x4_squeezeblocks(buf0, buf1, buf2, buf3, 1, &statex); + ctr[0] = rej_uniform(vec[0].coeffs, MLKEM_N, ctr[0], buf0, buflen); + ctr[1] = rej_uniform(vec[1].coeffs, MLKEM_N, ctr[1], buf1, buflen); + ctr[2] = rej_uniform(vec[2].coeffs, MLKEM_N, ctr[2], buf2, buflen); + ctr[3] = rej_uniform(vec[3].coeffs, MLKEM_N, ctr[3], buf3, buflen); + } + + xof_x4_release(&statex); +} + +/* + * Generate a single A matrix entry from a seed, using rejection + * sampling on the output of a XOF. 
+ */ +static void gen_matrix_entry(poly *entry, uint8_t seed[MLKEM_SYMBYTES + 2]) +__contract__( + requires(memory_no_alias(entry, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES + 2)) + assigns(memory_slice(entry, sizeof(poly))) + ensures(array_bound(entry->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +{ + xof_ctx state; + uint8_t buf[MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE]; + unsigned int ctr, buflen; + + shake128_inc_init(&state); + xof_absorb(&state, seed, MLKEM_SYMBYTES + 2); + + /* Initially, squeeze + sample heuristic number of MLKEM_GEN_MATRIX_NBLOCKS. + */ + /* This should generate the matrix entry with high probability. */ + xof_squeezeblocks(buf, MLKEM_GEN_MATRIX_NBLOCKS, &state); + buflen = MLKEM_GEN_MATRIX_NBLOCKS * XOF_RATE; + ctr = rej_uniform(entry->coeffs, MLKEM_N, 0, buf, buflen); + + /* Squeeze + sample one more block at a time until we're done */ + buflen = XOF_RATE; + while (ctr < MLKEM_N) + __loop__( + assigns(ctr, state, memory_slice(entry, sizeof(poly)), object_whole(buf)) + invariant(0 <= ctr && ctr <= MLKEM_N) + invariant(ctr > 0 ==> array_bound(entry->coeffs, 0, ctr, + 0, MLKEM_Q))) + { + xof_squeezeblocks(buf, 1, &state); + ctr = rej_uniform(entry->coeffs, MLKEM_N, ctr, buf, buflen); + } + + xof_release(&state); +} + +#if !defined(MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER) +/* This namespacing is not done at the top to avoid a naming conflict + * with native backends, which are currently not yet namespaced. */ +#define poly_permute_bitrev_to_custom \ + MLKEM_NAMESPACE(poly_permute_bitrev_to_custom) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +__contract__( + /* We don't specify that this should be a permutation, but only + * that it does not change the bound established at the end of gen_matrix. 
*/ + requires(memory_no_alias(data, sizeof(poly))) + requires(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(data, sizeof(poly))) + ensures(array_bound(data->coeffs, 0, MLKEM_N, 0, MLKEM_Q))) { ((void)data); } +#endif /* !MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER */ + +/* Not static for benchmarking */ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +{ + unsigned i, j; + /* + * We generate four separate seed arrays rather than a single one to work + * around limitations in CBMC function contracts dealing with disjoint slices + * of the same parent object. + */ + + ALIGN uint8_t seed0[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed1[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed2[MLKEM_SYMBYTES + 2]; + ALIGN uint8_t seed3[MLKEM_SYMBYTES + 2]; + uint8_t *seedxy[4]; + seedxy[0] = seed0; + seedxy[1] = seed1; + seedxy[2] = seed2; + seedxy[3] = seed3; + + for (j = 0; j < KECCAK_WAY; j++) + { + memcpy(seedxy[j], seed, MLKEM_SYMBYTES); + } + + for (i = 0; i < (MLKEM_K * MLKEM_K / KECCAK_WAY) * KECCAK_WAY; + i += KECCAK_WAY) + { + uint8_t x, y; + + for (j = 0; j < KECCAK_WAY; j++) + { + x = (i + j) / MLKEM_K; + y = (i + j) % MLKEM_K; + if (transposed) + { + seedxy[j][MLKEM_SYMBYTES + 0] = x; + seedxy[j][MLKEM_SYMBYTES + 1] = y; + } + else + { + seedxy[j][MLKEM_SYMBYTES + 0] = y; + seedxy[j][MLKEM_SYMBYTES + 1] = x; + } + } + + /* + * This call writes across polyvec boundaries for K=2 and K=3. + * This is intentional and safe. + */ + gen_matrix_entry_x4(&a[0].vec[0] + i, seedxy); + } + + /* For the leftover polynomial (if any), we use single Keccak. 
*/ + if (i < MLKEM_K * MLKEM_K) + { + uint8_t x, y; + x = i / MLKEM_K; + y = i % MLKEM_K; + + if (transposed) + { + seed0[MLKEM_SYMBYTES + 0] = x; + seed0[MLKEM_SYMBYTES + 1] = y; + } + else + { + seed0[MLKEM_SYMBYTES + 0] = y; + seed0[MLKEM_SYMBYTES + 1] = x; + } + + gen_matrix_entry(&a[0].vec[0] + i, seed0); + i++; + } + + cassert(i == MLKEM_K * MLKEM_K, + "gen_matrix: failed to generate whole matrix"); + + /* + * The public matrix is generated in NTT domain. If the native backend + * uses a custom order in NTT domain, permute A accordingly. + */ + for (i = 0; i < MLKEM_K; i++) + { + for (j = 0; j < MLKEM_K; j++) + { + poly_permute_bitrev_to_custom(&a[i].vec[j]); + } + } +} + +/************************************************* + * Name: matvec_mul + * + * Description: Computes matrix-vector product in NTT domain, + * via Montgomery multiplication. + * + * Arguments: - polyvec *out: Pointer to output polynomial vector + * - polyvec a[MLKEM_K]: Input matrix. Must be in NTT domain + * and have coefficients of absolute value < 4096. + * - polyvec *v: Input polynomial vector. Must be in NTT domain. + * - polyvec *vc: Mulcache for v, computed via + * polyvec_mulcache_compute(). 
+ **************************************************/ +static void matvec_mul(polyvec *out, const polyvec a[MLKEM_K], const polyvec *v, + const polyvec_mulcache *vc) +__contract__( + requires(memory_no_alias(out, sizeof(polyvec))) + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(v, sizeof(polyvec))) + requires(memory_no_alias(vc, sizeof(polyvec_mulcache))) + requires(forall(k0, 0, MLKEM_K, + forall(k1, 0, MLKEM_K, + array_bound(a[k0].vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)))) + assigns(object_whole(out))) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + __loop__( + assigns(i, object_whole(out)) + invariant(i >= 0 && i <= MLKEM_K)) + { + polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a[i], v, vc); + } +} + + +/* Check that the arithmetic in indcpa_keypair_derand() does not overflow */ +STATIC_ASSERT(NTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_keypair_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + const uint8_t *publicseed = buf; + const uint8_t *noiseseed = buf + MLKEM_SYMBYTES; + polyvec a[MLKEM_K], e, pkpv, skpv; + polyvec_mulcache skpv_cache; + + ALIGN uint8_t coins_with_domain_separator[MLKEM_SYMBYTES + 1]; + /* Concatenate coins with MLKEM_K for domain separation of security levels */ + memcpy(coins_with_domain_separator, coins, MLKEM_SYMBYTES); + coins_with_domain_separator[MLKEM_SYMBYTES] = MLKEM_K; + + hash_g(buf, coins_with_domain_separator, MLKEM_SYMBYTES + 1); + + gen_matrix(a, publicseed, 0 /* no transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, e.vec + 0, e.vec + 1, + noiseseed, 0, 1, 2, 3); +#elif MLKEM_K == 3 + /* + * Only the first three output buffers are needed. + * The fourth output buffer is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 0, 1, 2, + 0xFF /* irrelevant */); + /* Same here: the fourth output buffer and the 0xFF argument are dummies */ + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, + pkpv.vec + 0 /* irrelevant */, noiseseed, 3, 4, 5, + 0xFF /* irrelevant */); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(skpv.vec + 0, skpv.vec + 1, skpv.vec + 2, skpv.vec + 3, + noiseseed, 0, 1, 2, 3); + poly_getnoise_eta1_4x(e.vec + 0, e.vec + 1, e.vec + 2, e.vec + 3, noiseseed, + 4, 5, 6, 7); +#endif + + polyvec_ntt(&skpv); + polyvec_ntt(&e); + + polyvec_mulcache_compute(&skpv_cache, &skpv); + matvec_mul(&pkpv, a, &skpv, &skpv_cache); + polyvec_tomont(&pkpv); + + /* Arithmetic cannot overflow, see static assertion above this function */ + polyvec_add(&pkpv, &e); + polyvec_reduce(&pkpv); + polyvec_reduce(&skpv); + + pack_sk(sk, &skpv); + pack_pk(pk, &pkpv, publicseed); +} + + +/* Check that the arithmetic in indcpa_enc() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA1 < INT16_MAX, indcpa_enc_bound_0) +STATIC_ASSERT(INVNTT_BOUND + MLKEM_ETA2 + MLKEM_Q < INT16_MAX, + indcpa_enc_bound_1) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t seed[MLKEM_SYMBYTES]; + polyvec sp, pkpv, ep, at[MLKEM_K], b; + poly v, k, epp; + polyvec_mulcache sp_cache; + + unpack_pk(&pkpv, seed, pk); + poly_frommsg(&k, m); + gen_matrix(at, seed, 1 /* transpose */); + +#if MLKEM_K == 2 + poly_getnoise_eta1122_4x(sp.vec + 0, sp.vec + 1, ep.vec + 0, ep.vec + 1, + coins, 0, 1, 2, 3); + poly_getnoise_eta2(&epp, coins, 4); +#elif MLKEM_K == 3 + /* + * In this call, only the first three output buffers are needed. + * The last parameter is a dummy that's overwritten later. 
+ */ + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, &b.vec[0], coins, 0, + 1, 2, 0xFF); + /* The fourth output buffer in this call _is_ used. */ + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, &epp, coins, 3, 4, + 5, 6); +#elif MLKEM_K == 4 + poly_getnoise_eta1_4x(sp.vec + 0, sp.vec + 1, sp.vec + 2, sp.vec + 3, coins, + 0, 1, 2, 3); + poly_getnoise_eta2_4x(ep.vec + 0, ep.vec + 1, ep.vec + 2, ep.vec + 3, coins, + 4, 5, 6, 7); + poly_getnoise_eta2(&epp, coins, 8); +#endif + + polyvec_ntt(&sp); + + polyvec_mulcache_compute(&sp_cache, &sp); + matvec_mul(&b, at, &sp, &sp_cache); + polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache); + + polyvec_invntt_tomont(&b); + poly_invntt_tomont(&v); + + /* Arithmetic cannot overflow, see static assertions above this function */ + polyvec_add(&b, &ep); + poly_add(&v, &epp); + poly_add(&v, &k); + + polyvec_reduce(&b); + poly_reduce(&v); + + pack_ciphertext(c, &b, &v); +} + +/* Check that the arithmetic in indcpa_dec() does not overflow */ +STATIC_ASSERT(INVNTT_BOUND + MLKEM_Q < INT16_MAX, indcpa_dec_bound_0) + +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +{ + polyvec b, skpv; + poly v, sb; + + unpack_ciphertext(&b, &v, c); + unpack_sk(&skpv, sk); + + polyvec_ntt(&b); + polyvec_basemul_acc_montgomery(&sb, &skpv, &b); + poly_invntt_tomont(&sb); + + /* Arithmetic cannot overflow, see static assertion above this function */ + poly_sub(&v, &sb); + poly_reduce(&v); + + poly_tomsg(m, &v); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.h new file mode 100644 index 0000000000..011f1aa4fe --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/indcpa.h @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef INDCPA_H +#define INDCPA_H + 
+#include +#include "cbmc.h" +#include "common.h" +#include "polyvec.h" + +#define gen_matrix MLKEM_NAMESPACE(gen_matrix) +/************************************************* + * Name: gen_matrix + * + * Description: Deterministically generate matrix A (or the transpose of A) + * from a seed. Entries of the matrix are polynomials that look + * uniformly random. Performs rejection sampling on output of + * a XOF + * + * Arguments: - polyvec *a: pointer to output matrix A + * - const uint8_t *seed: pointer to input seed + * - int transposed: boolean deciding whether A or A^T is generated + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void gen_matrix(polyvec *a, const uint8_t seed[MLKEM_SYMBYTES], int transposed) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec) * MLKEM_K)) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires(transposed == 0 || transposed == 1) + assigns(object_whole(a)) + ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K, + array_bound(a[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))) +); + +#define indcpa_keypair_derand MLKEM_NAMESPACE(indcpa_keypair_derand) +/************************************************* + * Name: indcpa_keypair_derand + * + * Description: Generates public and private key for the CPA-secure + * public-key encryption scheme underlying ML-KEM + * + * Arguments: - uint8_t *pk: pointer to output public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES bytes) + * - uint8_t *sk: pointer to output private key + * (of length MLKEM_INDCPA_SECRETKEYBYTES bytes) + * - const uint8_t *coins: pointer to input randomness + * (of length MLKEM_SYMBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, 
MLKEM_INDCPA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +#define indcpa_enc MLKEM_NAMESPACE(indcpa_enc) +/************************************************* + * Name: indcpa_enc + * + * Description: Encryption function of the CPA-secure + * public-key encryption scheme underlying ML-KEM. + * + * Arguments: - uint8_t *c: pointer to output ciphertext + * (of length MLKEM_INDCPA_BYTES bytes) + * - const uint8_t *m: pointer to input message + * (of length MLKEM_INDCPA_MSGBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (of length MLKEM_INDCPA_PUBLICKEYBYTES) + * - const uint8_t *coins: pointer to input random coins used as + * seed (of length MLKEM_SYMBYTES) to deterministically generate all randomness + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCPA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(c)) +); + +#define indcpa_dec MLKEM_NAMESPACE(indcpa_dec) +/************************************************* + * Name: indcpa_dec + * + * Description: Decryption function of the CPA-secure + * public-key encryption scheme underlying ML-KEM. 
+ * + * Arguments: - uint8_t *m: pointer to output decrypted message + * (of length MLKEM_INDCPA_MSGBYTES) + * - const uint8_t *c: pointer to input ciphertext + * (of length MLKEM_INDCPA_BYTES) + * - const uint8_t *sk: pointer to input secret key + * (of length MLKEM_INDCPA_SECRETKEYBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES], + const uint8_t c[MLKEM_INDCPA_BYTES], + const uint8_t sk[MLKEM_INDCPA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(c, MLKEM_INDCPA_BYTES)) + requires(memory_no_alias(m, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCPA_SECRETKEYBYTES)) + assigns(object_whole(m)) +); + +#endif /* INDCPA_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.c new file mode 100644 index 0000000000..5779d3273a --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.c @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include +#include + +#include "indcpa.h" +#include "kem.h" +#include "randombytes.h" +#include "symmetric.h" +#include "verify.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define check_pk MLKEM_NAMESPACE(check_pk) +#define check_sk MLKEM_NAMESPACE(check_sk) +/* End of static namespacing */ + +#if defined(CBMC) +/* Redeclaration with contract needed for CBMC only */ +int memcmp(const void *str1, const void *str2, size_t n) +__contract__( + requires(memory_no_alias(str1, n)) + requires(memory_no_alias(str2, n)) +); +#endif /* CBMC */ + +/************************************************* + * Name: check_pk + * + * Description: Implements modulus check mandated by FIPS203, + * i.e., ensures that coefficients are in [0,q-1]. 
+ * Described in Section 7.2 of FIPS203. + * + * Arguments: - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + polyvec p; + uint8_t p_reencoded[MLKEM_POLYVECBYTES]; + polyvec_frombytes(&p, pk); + polyvec_reduce(&p); + polyvec_tobytes(p_reencoded, &p); + /* Data is public, so a variable-time memcmp() is OK */ + if (memcmp(pk, p_reencoded, MLKEM_POLYVECBYTES)) + { + return -1; + } + return 0; +} + +/************************************************* + * Name: check_sk + * + * Description: Implements public key hash check mandated by FIPS203, + * i.e., ensures that + * sk[768𝑘+32 ∶ 768𝑘+64] = H(pk)= H(sk[384𝑘 : 768𝑘+32]) + * Described in Section 7.3 of FIPS203. + * + * Arguments: - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * Returns 0 on success, and -1 on failure + **************************************************/ +static int check_sk(const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t test[MLKEM_SYMBYTES]; + /* + * The parts of `sk` being hashed and compared here are public, so + * no secret information is leaked through the runtime or the return value + * of this function. 
+ */ + hash_h(test, sk + MLKEM_INDCPA_SECRETKEYBYTES, MLKEM_INDCCA_PUBLICKEYBYTES); + if (memcmp(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, test, + MLKEM_SYMBYTES)) + { + return -1; + } + return 0; +} + +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +{ + indcpa_keypair_derand(pk, sk, coins); + memcpy(sk + MLKEM_INDCPA_SECRETKEYBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_h(sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, pk, + MLKEM_INDCCA_PUBLICKEYBYTES); + /* Value z (second half of coins) for pseudo-random output on reject */ + memcpy(sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + coins + MLKEM_SYMBYTES, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + ALIGN uint8_t coins[2 * MLKEM_SYMBYTES]; + randombytes(coins, 2 * MLKEM_SYMBYTES); + crypto_kem_keypair_derand(pk, sk, coins); + return 0; +} + +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +{ + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain shared key (first half) and coins (second half) */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + + if (check_pk(pk)) + { + return -1; + } + + memcpy(buf, coins, MLKEM_SYMBYTES); + + /* Multitarget countermeasure for coins + contributory KEM */ + hash_h(buf + MLKEM_SYMBYTES, pk, MLKEM_INDCCA_PUBLICKEYBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(ct, buf, pk, kr + MLKEM_SYMBYTES); + + memcpy(ss, kr, MLKEM_SYMBYTES); + return 0; +} + +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +{ + ALIGN uint8_t coins[MLKEM_SYMBYTES]; + randombytes(coins, MLKEM_SYMBYTES); + return 
crypto_kem_enc_derand(ct, ss, pk, coins); +} + +int 
crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +{ + uint8_t fail; + ALIGN uint8_t buf[2 * MLKEM_SYMBYTES]; + /* Will contain key, coins */ + ALIGN uint8_t kr[2 * MLKEM_SYMBYTES]; + const uint8_t *pk = sk + MLKEM_INDCPA_SECRETKEYBYTES; + + if (check_sk(sk)) + { + return -1; + } + + indcpa_dec(buf, ct, sk); + + /* Multitarget countermeasure for coins + contributory KEM */ + memcpy(buf + MLKEM_SYMBYTES, + sk + MLKEM_INDCCA_SECRETKEYBYTES - 2 * MLKEM_SYMBYTES, MLKEM_SYMBYTES); + hash_g(kr, buf, 2 * MLKEM_SYMBYTES); + + /* Recompute and compare ciphertext */ + { + /* Temporary buffer */ + ALIGN uint8_t cmp[MLKEM_INDCCA_CIPHERTEXTBYTES]; + /* coins are in kr+MLKEM_SYMBYTES */ + indcpa_enc(cmp, buf, pk, kr + MLKEM_SYMBYTES); + fail = ct_memcmp(ct, cmp, MLKEM_INDCCA_CIPHERTEXTBYTES); + } + + /* Compute rejection key */ + { + /* Temporary buffer */ + ALIGN uint8_t tmp[MLKEM_SYMBYTES + MLKEM_INDCCA_CIPHERTEXTBYTES]; + memcpy(tmp, sk + MLKEM_INDCCA_SECRETKEYBYTES - MLKEM_SYMBYTES, + MLKEM_SYMBYTES); + memcpy(tmp + MLKEM_SYMBYTES, ct, MLKEM_INDCCA_CIPHERTEXTBYTES); + hash_j(ss, tmp, sizeof(tmp)); + } + + /* Copy true key to return buffer if fail is 0 */ + ct_cmov_zero(ss, kr, MLKEM_SYMBYTES, fail); + + return 0; +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.h new file mode 100644 index 0000000000..074e4771e4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/kem.h @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef KEM_H +#define KEM_H + +#include +#include "cbmc.h" +#include "common.h" + +/* Include to ensure consistency between internal kem.h + * and external mlkem_native.h. 
*/ +#include "mlkem_native.h" + +#if MLKEM_INDCCA_SECRETKEYBYTES != MLKEM_SECRETKEYBYTES(MLKEM_LVL) +#error Mismatch for SECRETKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_PUBLICKEYBYTES != MLKEM_PUBLICKEYBYTES(MLKEM_LVL) +#error Mismatch for PUBLICKEYBYTES between kem.h and mlkem_native.h +#endif + +#if MLKEM_INDCCA_CIPHERTEXTBYTES != MLKEM_CIPHERTEXTBYTES(MLKEM_LVL) +#error Mismatch for CIPHERTEXTBYTES between kem.h and mlkem_native.h +#endif + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * - uint8_t *coins: pointer to input randomness + * (an already allocated array filled with 2*MLKEM_SYMBYTES + * random bytes) + ** + * Returns 0 (success) + **************************************************/ +int crypto_kem_keypair_derand(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES], + const uint8_t *coins) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + requires(memory_no_alias(coins, 2 * MLKEM_SYMBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - uint8_t *sk: pointer to output private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + * bytes) + * + * 
Returns 0 (success) + **************************************************/ +int crypto_kem_keypair(uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(pk)) + assigns(object_whole(sk)) +); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + * bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + * bytes) + * - const uint8_t *coins: pointer to input randomness + * (an already allocated array filled with MLKEM_SYMBYTES random + * bytes) + ** + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc_derand(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES], + const uint8_t coins[MLKEM_SYMBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + requires(memory_no_alias(coins, MLKEM_SYMBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *pk: pointer to input public key + * (an already allocated array of MLKEM_INDCCA_PUBLICKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int crypto_kem_enc(uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + uint8_t ss[MLKEM_SSBYTES], + const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES]) +__contract__( + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(pk, MLKEM_INDCCA_PUBLICKEYBYTES)) + assigns(object_whole(ct)) + assigns(object_whole(ss)) +); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret + * (an already allocated array of MLKEM_SSBYTES bytes) + * - const uint8_t *ct: pointer to input cipher text + * (an already allocated array of MLKEM_INDCCA_CIPHERTEXTBYTES + *bytes) + * - const uint8_t *sk: pointer to input private key + * (an already allocated array of MLKEM_INDCCA_SECRETKEYBYTES + *bytes) + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. 
+ **************************************************/ +int crypto_kem_dec(uint8_t ss[MLKEM_SSBYTES], + const uint8_t ct[MLKEM_INDCCA_CIPHERTEXTBYTES], + const uint8_t sk[MLKEM_INDCCA_SECRETKEYBYTES]) +__contract__( + requires(memory_no_alias(ss, MLKEM_SSBYTES)) + requires(memory_no_alias(ct, MLKEM_INDCCA_CIPHERTEXTBYTES)) + requires(memory_no_alias(sk, MLKEM_INDCCA_SECRETKEYBYTES)) + assigns(object_whole(ss)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/mlkem_native.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/mlkem_native.h new file mode 100644 index 0000000000..4aed4efbba --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/mlkem_native.h @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Public API for mlkem-native + * + * This header defines the public API of a single build of mlkem-native. + * + * To use this header, make sure one of the following holds: + * + * - The config.h used for the build is available in the include paths. + * - The values of BUILD_INFO_LVL and BUILD_INFO_NAMESPACE are set, reflecting + * the security level (512/768/1024) and namespace of the build. + * + * This header specifies a build of mlkem-native for a fixed security level. + * If you need multiple builds, e.g. to build a library offering multiple + * security levels, you need multiple instances of this header. + */ + +/* NOTE: To use multiple instances of this header, use separate guards. */ +#ifndef MLKEM_NATIVE_H +#define MLKEM_NATIVE_H + +#include + +/*************************** Build information ********************************/ + +/* + * Provide security level (BUILD_INFO_LVL) and namespacing + * (BUILD_INFO_NAMESPACE) + * + * By default, this is extracted from the configuration used for the build, + * but you can also set it manually to avoid a dependency on the build config. 
+ */ + +/* Skip this if BUILD_INFO_LVL has already been set */ +#if !defined(BUILD_INFO_LVL) + +/* Option 1: Extract from config */ +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif + +#if MLKEM_K == 2 +#define BUILD_INFO_LVL 512 +#elif MLKEM_K == 3 +#define BUILD_INFO_LVL 768 +#elif MLKEM_K == 4 +#define BUILD_INFO_LVL 1024 +#else +#error MLKEM_K not set by config file +#endif + +#ifndef MLKEM_NAMESPACE_PREFIX +#error MLKEM_NAMESPACE_PREFIX not set by config file +#endif + +#define BUILD_INFO_CONCAT_(x, y) x##_##y +#define BUILD_INFO_CONCAT(x, y) BUILD_INFO_CONCAT_(x, y) +#define BUILD_INFO_NAMESPACE(sym) BUILD_INFO_CONCAT(MLKEM_NAMESPACE_PREFIX, sym) + +#endif /* BUILD_INFO_LVL */ + +/* Option 2: Provide BUILD_INFO_LVL and BUILD_INFO_NAMESPACE manually */ + +/* #define BUILD_INFO_LVL ADJUSTME */ +/* #define BUILD_INFO_NAMESPACE(sym) ADJUSTME */ + +/******************************* Key sizes ************************************/ + +/* Sizes of cryptographic material, per level */ +#define MLKEM512_SECRETKEYBYTES 1632 +#define MLKEM512_PUBLICKEYBYTES 800 +#define MLKEM512_CIPHERTEXTBYTES 768 + +#define MLKEM768_SECRETKEYBYTES 2400 +#define MLKEM768_PUBLICKEYBYTES 1184 +#define MLKEM768_CIPHERTEXTBYTES 1088 + +#define MLKEM1024_SECRETKEYBYTES 3168 +#define MLKEM1024_PUBLICKEYBYTES 1568 +#define MLKEM1024_CIPHERTEXTBYTES 1568 + +/* Size of randomness coins in bytes (level-independent) */ +#define MLKEM_SYMBYTES 32 +#define MLKEM512_SYMBYTES MLKEM_SYMBYTES +#define MLKEM768_SYMBYTES MLKEM_SYMBYTES +#define MLKEM1024_SYMBYTES MLKEM_SYMBYTES +/* Size of shared secret in bytes (level-independent) */ +#define MLKEM_BYTES 32 +#define MLKEM512_BYTES MLKEM_BYTES +#define MLKEM768_BYTES MLKEM_BYTES +#define MLKEM1024_BYTES MLKEM_BYTES + +/* Sizes of cryptographic material, as a function of LVL=512,768,1024 */ +#define MLKEM_SECRETKEYBYTES_(LVL) MLKEM##LVL##_SECRETKEYBYTES +#define MLKEM_PUBLICKEYBYTES_(LVL) 
MLKEM##LVL##_PUBLICKEYBYTES +#define MLKEM_CIPHERTEXTBYTES_(LVL) MLKEM##LVL##_CIPHERTEXTBYTES +#define MLKEM_SECRETKEYBYTES(LVL) MLKEM_SECRETKEYBYTES_(LVL) +#define MLKEM_PUBLICKEYBYTES(LVL) MLKEM_PUBLICKEYBYTES_(LVL) +#define MLKEM_CIPHERTEXTBYTES(LVL) MLKEM_CIPHERTEXTBYTES_(LVL) + +/****************************** Function API **********************************/ + +/************************************************* + * Name: crypto_kem_keypair_derand + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t pk[]: pointer to output public key, an array of + * length MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t sk[]: pointer to output private key, an array of + * length MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * 2*MLKEM_SYMBYTES uniformly random bytes. + * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair_derand)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)], const uint8_t *coins); + +/************************************************* + * Name: crypto_kem_keypair + * + * Description: Generates public and private key + * for CCA-secure ML-KEM key encapsulation mechanism + * + * Arguments: - uint8_t *pk: pointer to output public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - uint8_t *sk: pointer to output private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. 
+ * + * Returns 0 (success) + **************************************************/ +int BUILD_INFO_NAMESPACE(keypair)( + uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_enc_derand + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * - const uint8_t *coins: pointer to input randomness, an array of + * MLKEM_SYMBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. + **************************************************/ +int BUILD_INFO_NAMESPACE(enc_derand)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)], + const uint8_t coins[MLKEM_SYMBYTES]); + +/************************************************* + * Name: crypto_kem_enc + * + * Description: Generates cipher text and shared + * secret for given public key + * + * Arguments: - uint8_t *ct: pointer to output cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *pk: pointer to input public key, an array of + * MLKEM{512,768,1024}_PUBLICKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the public key modulus check (see Section 7.2 + * of FIPS203) fails. 
+ **************************************************/ +int BUILD_INFO_NAMESPACE(enc)( + uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], uint8_t ss[MLKEM_BYTES], + const uint8_t pk[MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL)]); + +/************************************************* + * Name: crypto_kem_dec + * + * Description: Generates shared secret for given + * cipher text and private key + * + * Arguments: - uint8_t *ss: pointer to output shared secret, an array of + * MLKEM_BYTES bytes. + * - const uint8_t *ct: pointer to input cipher text, an array of + * MLKEM{512,768,1024}_CIPHERTEXTBYTES bytes. + * - const uint8_t *sk: pointer to input private key, an array of + * MLKEM{512,768,1024}_SECRETKEYBYTES bytes. + * + * Returns 0 on success, and -1 if the secret key hash check (see Section 7.3 of + * FIPS203) fails. + * + * On failure, ss will contain a pseudo-random value. + **************************************************/ +int BUILD_INFO_NAMESPACE(dec)( + uint8_t ss[MLKEM_BYTES], + const uint8_t ct[MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL)], + const uint8_t sk[MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL)]); + +/****************************** Standard API *********************************/ + +/* If desired, export API in CRYPTO_xxx and crypto_kem_xxx format as used + * e.g. by SUPERCOP and NIST. + * + * Remove this if you don't need it, or if you need multiple instances + * of this header. 
*/ + +#if !defined(BUILD_INFO_NO_STANDARD_API) +#define CRYPTO_SECRETKEYBYTES MLKEM_SECRETKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_PUBLICKEYBYTES MLKEM_PUBLICKEYBYTES(BUILD_INFO_LVL) +#define CRYPTO_CIPHERTEXTBYTES MLKEM_CIPHERTEXTBYTES(BUILD_INFO_LVL) + +#define CRYPTO_SYMBYTES MLKEM_SYMBYTES +#define CRYPTO_BYTES MLKEM_BYTES + +#define crypto_kem_keypair_derand BUILD_INFO_NAMESPACE(keypair_derand) +#define crypto_kem_keypair BUILD_INFO_NAMESPACE(keypair) +#define crypto_kem_enc_derand BUILD_INFO_NAMESPACE(enc_derand) +#define crypto_kem_enc BUILD_INFO_NAMESPACE(enc) +#define crypto_kem_dec BUILD_INFO_NAMESPACE(dec) +#endif /* BUILD_INFO_NO_STANDARD_API */ + +/********************************* Cleanup ************************************/ + +/* Unset build information to allow multiple instances of this header. + * Keep this commented out when using the standard API. */ +/* #undef BUILD_INFO_LVL */ +/* #undef BUILD_INFO_NAMESPACE */ + +#endif /* MLKEM_NATIVE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.c new file mode 100644 index 0000000000..02b45215c2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include + +#include "arith_backend.h" +#include "debug/debug.h" +#include "ntt.h" +#include "reduce.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define ntt_butterfly_block MLKEM_NAMESPACE(ntt_butterfly_block) +#define ntt_layer MLKEM_NAMESPACE(ntt_layer) +#define invntt_layer MLKEM_NAMESPACE(invntt_layer) +/* End of static namespacing */ + +#if !defined(MLKEM_USE_NATIVE_NTT) +/* + * Computes a block of CT butterflies with a fixed twiddle factor, + * using Montgomery multiplication. 
+ * Parameters: + * - r: Pointer to base of polynomial (_not_ the base of butterfly block) + * - zeta: Twiddle factor to use for the butterfly. This must be in + * Montgomery form and signed canonical. + * - start: Offset to the beginning of the butterfly block + * - len: Index difference between coefficients subject to a butterfly + * - bound: Ghost variable describing coefficient bound: Prior to `start`, + * coefficients must be bound by `bound + MLKEM_Q`. Post `start`, + * they must be bound by `bound`. + * When this function returns, output coefficients in the index range + * [start, start+2*len) have bound bumped to `bound + MLKEM_Q`. + * Example: + * - start=8, len=4 + * This would compute the following four butterflies + * 8 -- 12 + * 9 -- 13 + * 10 -- 14 + * 11 -- 15 + * - start=4, len=2 + * This would compute the following two butterflies + * 4 -- 6 + * 5 -- 7 + */ +static void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, + int len, int bound) +__contract__( + requires(0 <= start && start < MLKEM_N) + requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) + requires(0 <= bound && bound < INT16_MAX - MLKEM_Q) + requires(-HALF_Q < zeta && zeta < HALF_Q) + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(array_abs_bound(r, 0, start, bound + MLKEM_Q)) + requires(array_abs_bound(r, start, MLKEM_N, bound)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q)) + ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound))) +{ + /* `bound` is a ghost variable only needed in the CBMC specification */ + int j; + ((void)bound); + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + /* + * Coefficients are updated in strided pairs, so the bounds for the + * intermediate states alternate twice between the old and new bound + */ + invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q)) + 
invariant(array_abs_bound(r, j, start + len, bound)) + invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q)) + invariant(array_abs_bound(r, j + len, MLKEM_N, bound))) + { + int16_t t; + t = fqmul(r[j + len], zeta); + r[j + len] = r[j] - t; + r[j] = r[j] + t; + } +} + +/* + *Compute one layer of forward NTT + * Parameters: + * - r: Pointer to base of polynomial + * - len: Stride of butterflies in this layer. + * - layer: Ghost variable indicating which layer is being applied. + * Must match `len` via `len == MLKEM_N >> layer`. + * Note: `len` could be dropped and computed in the function, but + * we are following the structure of the reference NTT from the + * official Kyber implementation here, merely adding `layer` as + * a ghost variable for the specifications. + */ +static void ntt_layer(int16_t r[MLKEM_N], int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(1 <= layer && layer <= 7 && len == (MLKEM_N >> layer)) + requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable only needed in the CBMC specification */ + ((void)layer); + /* Twiddle factors for layer n start at index 2^(layer-1) */ + k = MLKEM_N / (2 * len); + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(0 <= start && start < MLKEM_N + 2 * len) + invariant(0 <= k && k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N) + invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q)) + invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q))) + { + int16_t zeta = zetas[k++]; + ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q); + } +} + +/* + * Compute full forward NTT + * NOTE: This particular implementation satisfies a much tighter + * bound on the output coefficients (5*q) than the contractual one (8*q), + * but this is not 
needed in the calling code. Should we change the + * base multiplication strategy to require smaller NTT output bounds, + * the proof may need strengthening. + */ + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + int len, layer; + int16_t *r; + POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); + r = p->coeffs; + + for (len = 128, layer = 1; len >= 2; len >>= 1, layer++) + __loop__( + invariant(1 <= layer && layer <= 8 && len == (MLKEM_N >> layer)) + invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))) + { + ntt_layer(r, len, layer); + } + + /* Check the stronger bound */ + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); +} +#else /* MLKEM_USE_NATIVE_NTT */ + +/* Check that bound for native NTT implies contractual bound */ +STATIC_ASSERT(NTT_BOUND_NATIVE <= NTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *p) +{ + POLY_BOUND_MSG(p, MLKEM_Q, "native ntt input"); + ntt_native(p); + POLY_BOUND_MSG(p, NTT_BOUND_NATIVE, "native ntt output"); +} +#endif /* MLKEM_USE_NATIVE_NTT */ + +#if !defined(MLKEM_USE_NATIVE_INTT) + +/* Check that bound for reference invNTT implies contractual bound */ +#define INVNTT_BOUND_REF (3 * MLKEM_Q / 4) +STATIC_ASSERT(INVNTT_BOUND_REF <= INVNTT_BOUND, invntt_bound) + +/* Compute one layer of inverse NTT */ +static void invntt_layer(int16_t *r, int len, int layer) +__contract__( + requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N)) + requires(2 <= len && len <= 128 && 1 <= layer && layer <= 7) + requires(len == (1 << (8 - layer))) + requires(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N)) + ensures(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) +{ + int start, k; + /* `layer` is a ghost variable used only in the specification */ + ((void)layer); + k = MLKEM_N / len - 1; + for (start = 0; start < MLKEM_N; start += 2 * len) + __loop__( + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q)) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + /* 
Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) */ + invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) + { + int j; + int16_t zeta = zetas[k--]; + for (j = start; j < start + len; j++) + __loop__( + invariant(start <= j && j <= start + len) + invariant(0 <= start && start <= MLKEM_N && 0 <= k && k <= 127) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + int16_t t = r[j]; + r[j] = barrett_reduce(t + r[j + len]); + r[j + len] = r[j + len] - t; + r[j + len] = fqmul(r[j + len], zeta); + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + /* + * Scale input polynomial to account for Montgomery factor + * and NTT twist. This also brings coefficients down to + * absolute value < MLKEM_Q. + */ + int j, len, layer; + const int16_t f = 1441; + int16_t *r = p->coeffs; + + for (j = 0; j < MLKEM_N; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N) + invariant(array_abs_bound(r, 0, j, MLKEM_Q))) + { + r[j] = fqmul(r[j], f); + } + + /* Run the invNTT layers */ + for (len = 2, layer = 7; len <= 128; len <<= 1, layer--) + __loop__( + invariant(2 <= len && len <= 256 && 0 <= layer && layer <= 7 && len == (1 << (8 - layer))) + invariant(array_abs_bound(r, 0, MLKEM_N, MLKEM_Q))) + { + invntt_layer(p->coeffs, len, layer); + } + + POLY_BOUND_MSG(p, INVNTT_BOUND_REF, "ref intt output"); +} +#else /* MLKEM_USE_NATIVE_INTT */ + +/* Check that bound for native invNTT implies contractual bound */ +STATIC_ASSERT(INVNTT_BOUND_NATIVE <= INVNTT_BOUND, invntt_bound) + +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *p) +{ + intt_native(p); + POLY_BOUND_MSG(p, INVNTT_BOUND_NATIVE, "native intt output"); +} +#endif /* MLKEM_USE_NATIVE_INTT */ + +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +{ + int32_t t0, t1; + + BOUND(a, 2, 4096, "basemul input bound"); + + t0 = (int32_t)a[1] * b_cached; + t0 += (int32_t)a[0] * b[0]; + t1 = (int32_t)a[0] * b[1]; + t1 += 
(int32_t)a[1] * b[0]; + + /* |ti| < 2 * q * 2^15 */ + r[0] = montgomery_reduce(t0); + r[1] = montgomery_reduce(t1); + + BOUND(r, 2, 2 * MLKEM_Q, "basemul output bound"); +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.h new file mode 100644 index 0000000000..5592bb9a27 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/ntt.h @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef NTT_H +#define NTT_H + +#include +#include "cbmc.h" +#include "common.h" +#include "poly.h" +#include "reduce.h" + +#define zetas MLKEM_NAMESPACE(zetas) +extern const int16_t zetas[128]; + +#define poly_ntt MLKEM_NAMESPACE(poly_ntt) +/************************************************* + * Name: poly_ntt + * + * Description: Computes negacyclic number-theoretic transform (NTT) of + * a polynomial in place. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * (NOTE: Sometimes the input to the NTT is actually smaller, + * which gives better bounds.) 
+ * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_ntt(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, NTT_BOUND)) +); + +#define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) +/************************************************* + * Name: poly_invntt_tomont + * + * Description: Computes inverse of negacyclic number-theoretic transform (NTT) + * of a polynomial in place; + * inputs assumed to be in bitreversed order, output in normal + * order + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value. + * + * Arguments: - poly *r: pointer to in/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_invntt_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, INVNTT_BOUND)) +); + +#define basemul_cached MLKEM_NAMESPACE(basemul_cached) +/************************************************************ + * Name: basemul_cached + * + * Description: Computes a representative modulo q of + * (a0*b0 + a1*b_cached, a0*b1 + a1*b0)/65536 + * + * If b_cached is b1*zeta, this represents the + * product of (a0 + a1*X) and (b0 + b1*X) in + * Fq[X]/(X^2 - zeta). + * + * Arguments: - r: Pointer to output polynomial + * Upon return, coefficients are bound by + * 2*MLKEM_Q in absolute value. + * - a: Pointer to first input polynomial + * Must be coefficient-wise < 4096 in absolute value. 
+ * - b: Pointer to second input polynomial + * Can have arbitrary int16_t coefficients + * - b_cached: Some precomputed value, typically derived from + * b1 and a twiddle factor. Can be an arbitrary int16_t. + ************************************************************/ +MLKEM_NATIVE_INTERNAL_API +void basemul_cached(int16_t r[2], const int16_t a[2], const int16_t b[2], + int16_t b_cached) +__contract__( + requires(memory_no_alias(r, 2 * sizeof(int16_t))) + requires(memory_no_alias(a, 2 * sizeof(int16_t))) + requires(memory_no_alias(b, 2 * sizeof(int16_t))) + requires(array_bound(a, 0, 2, 0, UINT12_LIMIT)) + assigns(memory_slice(r, 2 * sizeof(int16_t))) + ensures(array_abs_bound(r, 0, 2, 2 * MLKEM_Q)) +); + + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/params.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/params.h new file mode 100644 index 0000000000..fa751f977b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/params.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef PARAMS_H +#define PARAMS_H + +#if defined(MLKEM_NATIVE_CONFIG_FILE) +#include MLKEM_NATIVE_CONFIG_FILE +#else +#include "config.h" +#endif /* MLKEM_NATIVE_CONFIG_FILE */ + +#if !defined(MLKEM_K) +#error MLKEM_K is not defined +#endif + +#define MLKEM_N 256 +#define MLKEM_Q 3329 +#define UINT12_LIMIT 4096 + +#define MLKEM_SYMBYTES 32 /* size in bytes of hashes, and seeds */ +#define MLKEM_SSBYTES 32 /* size in bytes of shared key */ + +#define MLKEM_POLYBYTES 384 +#define MLKEM_POLYVECBYTES (MLKEM_K * MLKEM_POLYBYTES) + +#if MLKEM_K == 2 +#define MLKEM_LVL 512 +#define MLKEM_ETA1 3 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 3 +#define MLKEM_LVL 768 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 128 +#define 
MLKEM_POLYCOMPRESSEDBYTES_DU 320 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#elif MLKEM_K == 4 +#define MLKEM_LVL 1024 +#define MLKEM_ETA1 2 +#define MLKEM_POLYCOMPRESSEDBYTES_DV 160 +#define MLKEM_POLYCOMPRESSEDBYTES_DU 352 +#define MLKEM_POLYVECCOMPRESSEDBYTES_DU (MLKEM_K * MLKEM_POLYCOMPRESSEDBYTES_DU) +#endif + +#define MLKEM_ETA2 2 + +#define MLKEM_INDCPA_MSGBYTES (MLKEM_SYMBYTES) +#define MLKEM_INDCPA_PUBLICKEYBYTES (MLKEM_POLYVECBYTES + MLKEM_SYMBYTES) +#define MLKEM_INDCPA_SECRETKEYBYTES (MLKEM_POLYVECBYTES) +#define MLKEM_INDCPA_BYTES \ + (MLKEM_POLYVECCOMPRESSEDBYTES_DU + MLKEM_POLYCOMPRESSEDBYTES_DV) + +#define MLKEM_INDCCA_PUBLICKEYBYTES (MLKEM_INDCPA_PUBLICKEYBYTES) +/* 32 bytes of additional space to save H(pk) */ +#define MLKEM_INDCCA_SECRETKEYBYTES \ + (MLKEM_INDCPA_SECRETKEYBYTES + MLKEM_INDCPA_PUBLICKEYBYTES + \ + 2 * MLKEM_SYMBYTES) +#define MLKEM_INDCCA_CIPHERTEXTBYTES (MLKEM_INDCPA_BYTES) + +#define KECCAK_WAY 4 +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.c new file mode 100644 index 0000000000..5807879df4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.c @@ -0,0 +1,583 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include +#include + +#include "arith_backend.h" +#include "cbd.h" +#include "cbmc.h" +#include "debug/debug.h" +#include "fips202x4.h" +#include "ntt.h" +#include "poly.h" +#include "reduce.h" +#include "symmetric.h" +#include "verify.h" + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 8)) + { + unsigned k; + uint16_t t[8]; + for (k = 0; k < 8; k++) + __loop__( + invariant(k >= 0 && k <= 8) + invariant(forall(r, 0, k, t[r] < (1u 
<< 11)))) + { + t[k] = scalar_compress_d11(a->coeffs[8 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 11-bit in size. + */ + r[11 * j + 0] = (t[0] >> 0) & 0xFF; + r[11 * j + 1] = (t[0] >> 8) | ((t[1] << 3) & 0xFF); + r[11 * j + 2] = (t[1] >> 5) | ((t[2] << 6) & 0xFF); + r[11 * j + 3] = (t[2] >> 2) & 0xFF; + r[11 * j + 4] = (t[2] >> 10) | ((t[3] << 1) & 0xFF); + r[11 * j + 5] = (t[3] >> 7) | ((t[4] << 4) & 0xFF); + r[11 * j + 6] = (t[4] >> 4) | ((t[5] << 7) & 0xFF); + r[11 * j + 7] = (t[5] >> 1) & 0xFF; + r[11 * j + 8] = (t[5] >> 9) | ((t[6] << 2) & 0xFF); + r[11 * j + 9] = (t[6] >> 6) | ((t[7] << 5) & 0xFF); + r[11 * j + 10] = (t[7] >> 3); + } + +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__(invariant(j >= 0 && j <= MLKEM_N / 4)) + { + unsigned k; + uint16_t t[4]; + for (k = 0; k < 4; k++) + __loop__( + invariant(k >= 0 && k <= 4) + invariant(forall(r, 0, k, t[r] < (1u << 10)))) + { + t[k] = scalar_compress_d10(a->coeffs[4 * j + k]); + } + + /* + * Make all implicit truncation explicit. No data is being + * truncated for the LHS's since each t[i] is 10-bit in size. 
+ */ + r[5 * j + 0] = (t[0] >> 0) & 0xFF; + r[5 * j + 1] = (t[0] >> 8) | ((t[1] << 2) & 0xFF); + r[5 * j + 2] = (t[1] >> 6) | ((t[2] << 4) & 0xFF); + r[5 * j + 3] = (t[2] >> 4) | ((t[3] << 6) & 0xFF); + r[5 * j + 4] = (t[3] >> 2); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +{ + unsigned j; +#if (MLKEM_POLYCOMPRESSEDBYTES_DU == 352) + for (j = 0; j < MLKEM_N / 8; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[8]; + uint8_t const *base = &a[11 * j]; + t[0] = 0x7FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x7FF & ((base[1] >> 3) | ((uint16_t)base[2] << 5)); + t[2] = 0x7FF & ((base[2] >> 6) | ((uint16_t)base[3] << 2) | + ((uint16_t)base[4] << 10)); + t[3] = 0x7FF & ((base[4] >> 1) | ((uint16_t)base[5] << 7)); + t[4] = 0x7FF & ((base[5] >> 4) | ((uint16_t)base[6] << 4)); + t[5] = 0x7FF & ((base[6] >> 7) | ((uint16_t)base[7] << 1) | + ((uint16_t)base[8] << 9)); + t[6] = 0x7FF & ((base[8] >> 2) | ((uint16_t)base[9] << 6)); + t[7] = 0x7FF & ((base[9] >> 5) | ((uint16_t)base[10] << 3)); + + for (k = 0; k < 8; k++) + __loop__( + invariant(0 <= k && k <= 8) + invariant(array_bound(r->coeffs, 0, 8 * j + k, 0, MLKEM_Q))) + { + r->coeffs[8 * j + k] = scalar_decompress_d11(t[k]); + } + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DU == 320) + for (j = 0; j < MLKEM_N / 4; j++) + __loop__( + invariant(0 <= j && j <= MLKEM_N / 4) + invariant(array_bound(r->coeffs, 0, 4 * j, 0, MLKEM_Q))) + { + int k; + uint16_t t[4]; + uint8_t const *base = &a[5 * j]; + + t[0] = 0x3FF & ((base[0] >> 0) | ((uint16_t)base[1] << 8)); + t[1] = 0x3FF & ((base[1] >> 2) | ((uint16_t)base[2] << 6)); + t[2] = 0x3FF & ((base[2] >> 4) | ((uint16_t)base[3] << 4)); + t[3] = 0x3FF & ((base[3] >> 6) | ((uint16_t)base[4] << 2)); + + for (k = 0; k < 4; k++) + 
__loop__( + invariant(0 <= k && k <= 4) + invariant(array_bound(r->coeffs, 0, 4 * j + k, 0, MLKEM_Q))) + { + r->coeffs[4 * j + k] = scalar_decompress_d10(t[k]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DU needs to be in {320,352}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + +#if (MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 16))) + { + t[j] = scalar_compress_d4(a->coeffs[8 * i + j]); + } + + r[i * 4] = t[0] | (t[1] << 4); + r[i * 4 + 1] = t[2] | (t[3] << 4); + r[i * 4 + 2] = t[4] | (t[5] << 4); + r[i * 4 + 3] = t[6] | (t[7] << 4); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + uint8_t t[8] = {0}; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(t, 0, j, 0, 32))) + { + t[j] = scalar_compress_d5(a->coeffs[8 * i + j]); + } + + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC, and use array indexing into + * r rather than pointer-arithmetic to simplify verification + */ + r[i * 5] = 0xFF & ((t[0] >> 0) | (t[1] << 5)); + r[i * 5 + 1] = 0xFF & ((t[1] >> 3) | (t[2] << 2) | (t[3] << 7)); + r[i * 5 + 2] = 0xFF & ((t[3] >> 1) | (t[4] << 4)); + r[i * 5 + 3] = 0xFF & ((t[4] >> 4) | (t[5] << 1) | (t[6] << 6)); + r[i * 5 + 4] = 0xFF & ((t[6] >> 2) | (t[7] << 3)); + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif +} + +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +{ + unsigned i; +#if 
(MLKEM_POLYCOMPRESSEDBYTES_DV == 128) + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, MLKEM_Q))) + { + r->coeffs[2 * i + 0] = scalar_decompress_d4((a[i] >> 0) & 0xF); + r->coeffs[2 * i + 1] = scalar_decompress_d4((a[i] >> 4) & 0xF); + } +#elif (MLKEM_POLYCOMPRESSEDBYTES_DV == 160) + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + uint8_t t[8]; + const int offset = i * 5; + /* + * Explicitly truncate to avoid warning about + * implicit truncation in CBMC and unwind loop for ease + * of proof. + */ + + /* + * Decompress 5 8-bit bytes (so 40 bits) into + * 8 5-bit values stored in t[] + */ + t[0] = 0x1F & (a[offset + 0] >> 0); + t[1] = 0x1F & ((a[offset + 0] >> 5) | (a[offset + 1] << 3)); + t[2] = 0x1F & (a[offset + 1] >> 2); + t[3] = 0x1F & ((a[offset + 1] >> 7) | (a[offset + 2] << 1)); + t[4] = 0x1F & ((a[offset + 2] >> 4) | (a[offset + 3] << 4)); + t[5] = 0x1F & (a[offset + 3] >> 1); + t[6] = 0x1F & ((a[offset + 3] >> 6) | (a[offset + 4] << 2)); + t[7] = 0x1F & (a[offset + 4] >> 3); + + /* and copy to the correct slice in r[] */ + for (j = 0; j < 8; j++) + __loop__( + invariant(j >= 0 && j <= 8 && i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + r->coeffs[8 * i + j] = scalar_decompress_d5(t[j]); + } + } +#else +#error "MLKEM_POLYCOMPRESSEDBYTES_DV needs to be in {128, 160}" +#endif + + POLY_UBOUND(r, MLKEM_Q); +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + + for (i = 0; i < MLKEM_N / 2; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 2)) + { + const uint16_t t0 = a->coeffs[2 * i]; + const uint16_t t1 = a->coeffs[2 * i + 1]; + /* + * t0 and t1 are both < MLKEM_Q, so contain at 
most 12 bits each of + * significant data, so these can be packed into 24 bits or exactly + * 3 bytes, as follows. + */ + + /* Least significant bits 0 - 7 of t0. */ + r[3 * i + 0] = t0 & 0xFF; + + /* + * Most significant bits 8 - 11 of t0 become the least significant + * nibble of the second byte. The least significant 4 bits + * of t1 become the upper nibble of the second byte. + */ + r[3 * i + 1] = (t0 >> 8) | ((t1 << 4) & 0xF0); + + /* Bits 4 - 11 of t1 become the third byte. */ + r[3 * i + 2] = t1 >> 4; + } +} +#else /* MLKEM_USE_NATIVE_POLY_TOBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +{ + POLY_UBOUND(a, MLKEM_Q); + poly_tobytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOBYTES */ + +#if !defined(MLKEM_USE_NATIVE_POLY_FROMBYTES) +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 2; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 2) + invariant(array_bound(r->coeffs, 0, 2 * i, 0, UINT12_LIMIT))) + { + const uint8_t t0 = a[3 * i + 0]; + const uint8_t t1 = a[3 * i + 1]; + const uint8_t t2 = a[3 * i + 2]; + r->coeffs[2 * i + 0] = t0 | ((t1 << 8) & 0xFFF); + r->coeffs[2 * i + 1] = (t1 >> 4) | (t2 << 4); + } + + /* Note that the coefficients are not canonical */ + POLY_UBOUND(r, 4096); +} +#else /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +{ + poly_frombytes_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLY_FROMBYTES */ + +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +{ + unsigned i; +#if (MLKEM_INDCPA_MSGBYTES != MLKEM_N / 8) +#error "MLKEM_INDCPA_MSGBYTES must be equal to MLKEM_N/8 bytes!" 
+#endif + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8) + invariant(array_bound(r->coeffs, 0, 8 * i, 0, MLKEM_Q))) + { + unsigned j; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i < MLKEM_N / 8 && j >= 0 && j <= 8) + invariant(array_bound(r->coeffs, 0, 8 * i + j, 0, MLKEM_Q))) + { + /* Prevent the compiler from recognizing this as a bit selection */ + uint8_t mask = value_barrier_u8(1u << j); + r->coeffs[8 * i + j] = ct_sel_int16(HALF_Q, 0, msg[i] & mask); + } + } + POLY_BOUND_MSG(r, MLKEM_Q, "poly_frommsg output"); +} + +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *a) +{ + unsigned i; + POLY_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_N / 8; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 8)) + { + unsigned j; + msg[i] = 0; + for (j = 0; j < 8; j++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N / 8 && j >= 0 && j <= 8)) + { + uint32_t t = scalar_compress_d1(a->coeffs[8 * i + j]); + msg[i] |= t << j; + } + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +{ + ALIGN uint8_t buf0[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf1[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf3[MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t extkey0[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey1[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey2[MLKEM_SYMBYTES + 1]; + ALIGN uint8_t extkey3[MLKEM_SYMBYTES + 1]; + memcpy(extkey0, seed, MLKEM_SYMBYTES); + memcpy(extkey1, seed, MLKEM_SYMBYTES); + memcpy(extkey2, seed, MLKEM_SYMBYTES); + memcpy(extkey3, seed, MLKEM_SYMBYTES); + extkey0[MLKEM_SYMBYTES] = nonce0; + extkey1[MLKEM_SYMBYTES] = nonce1; + extkey2[MLKEM_SYMBYTES] = nonce2; + extkey3[MLKEM_SYMBYTES] = nonce3; + prf_eta1_x4(buf0, buf1, buf2, buf3, extkey0, extkey1, extkey2, extkey3); + 
poly_cbd_eta1(r0, buf0); + poly_cbd_eta1(r1, buf1); + poly_cbd_eta1(r2, buf2); + poly_cbd_eta1(r3, buf3); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA1 + 1, "poly_getnoise_eta1_4x output 3"); +} + +#if MLKEM_K == 2 || MLKEM_K == 4 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +{ + ALIGN uint8_t buf[MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[MLKEM_SYMBYTES + 1]; + + memcpy(extkey, seed, MLKEM_SYMBYTES); + extkey[MLKEM_SYMBYTES] = nonce; + prf_eta2(buf, extkey); + + poly_cbd_eta2(r, buf); + + POLY_BOUND_MSG(r, MLKEM_ETA2 + 1, "poly_getnoise_eta2 output"); /* sampled via poly_cbd_eta2: use the ETA2 bound, as poly_getnoise_eta1122_4x does */ +} +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +{ + ALIGN uint8_t buf1[KECCAK_WAY / 2][MLKEM_ETA1 * MLKEM_N / 4]; + ALIGN uint8_t buf2[KECCAK_WAY / 2][MLKEM_ETA2 * MLKEM_N / 4]; + ALIGN uint8_t extkey[KECCAK_WAY][MLKEM_SYMBYTES + 1]; + memcpy(extkey[0], seed, MLKEM_SYMBYTES); + memcpy(extkey[1], seed, MLKEM_SYMBYTES); + memcpy(extkey[2], seed, MLKEM_SYMBYTES); + memcpy(extkey[3], seed, MLKEM_SYMBYTES); + extkey[0][MLKEM_SYMBYTES] = nonce0; + extkey[1][MLKEM_SYMBYTES] = nonce1; + extkey[2][MLKEM_SYMBYTES] = nonce2; + extkey[3][MLKEM_SYMBYTES] = nonce3; + + prf_eta1(buf1[0], extkey[0]); + prf_eta1(buf1[1], extkey[1]); + prf_eta2(buf2[0], extkey[2]); + prf_eta2(buf2[1], extkey[3]); + + poly_cbd_eta1(r0, buf1[0]); + poly_cbd_eta1(r1, buf1[1]); + poly_cbd_eta2(r2, buf2[0]); + poly_cbd_eta2(r3, buf2[1]); + + POLY_BOUND_MSG(r0, MLKEM_ETA1 + 1, "poly_getnoise_eta1122_4x output 0"); + POLY_BOUND_MSG(r1, MLKEM_ETA1 + 1, 
"poly_getnoise_eta1122_4x output 1"); + POLY_BOUND_MSG(r2, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 2"); + POLY_BOUND_MSG(r3, MLKEM_ETA2 + 1, "poly_getnoise_eta1122_4x output 3"); +} +#endif /* MLKEM_K == 2 */ + +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +{ + unsigned i; + POLY_BOUND(b_cache, 4096); + + for (i = 0; i < MLKEM_N / 4; i++) + __loop__( + assigns(i, object_whole(r)) + invariant(i >= 0 && i <= MLKEM_N / 4) + invariant(array_abs_bound(r->coeffs, 0, 4 * i, 2 * MLKEM_Q))) + { + basemul_cached(&r->coeffs[4 * i], &a->coeffs[4 * i], &b->coeffs[4 * i], + b_cache->coeffs[2 * i]); + basemul_cached(&r->coeffs[4 * i + 2], &a->coeffs[4 * i + 2], + &b->coeffs[4 * i + 2], b_cache->coeffs[2 * i + 1]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_TOMONT) +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + unsigned i; + const int16_t f = (1ULL << 32) % MLKEM_Q; /* 1353 */ + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_abs_bound(r->coeffs ,0, i, MLKEM_Q))) + { + r->coeffs[i] = fqmul(r->coeffs[i], f); + } + + POLY_BOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_TOMONT */ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +{ + poly_tomont_native(r); + POLY_BOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_TOMONT */ + +#if !defined(MLKEM_USE_NATIVE_POLY_REDUCE) +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(array_bound(r->coeffs, 0, i, 0, MLKEM_Q))) + { + /* Barrett reduction, giving signed canonical representative */ + int16_t t = barrett_reduce(r->coeffs[i]); + /* Conditional addition to get unsigned canonical representative */ + r->coeffs[i] = scalar_signed_to_unsigned_q(t); + } + + POLY_UBOUND(r, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_REDUCE */ +MLKEM_NATIVE_INTERNAL_API +void 
poly_reduce(poly *r) +{ + poly_reduce_native(r); + POLY_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLY_REDUCE */ + +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] + b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] + b->coeffs[i]; + } +} + +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +{ + unsigned i; + for (i = 0; i < MLKEM_N; i++) + __loop__( + invariant(i >= 0 && i <= MLKEM_N) + invariant(forall(k0, i, MLKEM_N, r->coeffs[k0] == loop_entry(*r).coeffs[k0])) + invariant(forall(k1, 0, i, r->coeffs[k1] == loop_entry(*r).coeffs[k1] - b->coeffs[k1]))) + { + r->coeffs[i] = r->coeffs[i] - b->coeffs[i]; + } +} + +#if !defined(MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE) +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + unsigned i; + for (i = 0; i < MLKEM_N / 4; i++) + __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) + { + x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); + x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + } + POLY_BOUND(x, MLKEM_Q); +} +#else /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +{ + poly_mulcache_compute_native(x, a); + /* Omitting POLY_BOUND(x, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ +} +#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.h new file mode 100644 index 0000000000..1e8c109c6e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/poly.h @@ -0,0 +1,805 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLY_H +#define POLY_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "reduce.h" +#include "verify.h" + +/* Absolute exclusive upper bound for the output of the inverse NTT */ +#define INVNTT_BOUND (8 * MLKEM_Q) + +/* Absolute exclusive upper bound for the output of the forward NTT */ +#define NTT_BOUND (8 * MLKEM_Q) + +/* + * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial + * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] + */ +#define poly MLKEM_NAMESPACE(poly) +typedef struct +{ + int16_t coeffs[MLKEM_N]; +} ALIGN poly; + +/* + * INTERNAL presentation of precomputed data speeding up + * the base multiplication of two polynomials in NTT domain. + */ +#define poly_mulcache MLKEM_NAMESPACE(poly_mulcache) +typedef struct +{ + int16_t coeffs[MLKEM_N >> 1]; +} poly_mulcache; + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. 
*/ +#define scalar_compress_d1 MLKEM_NAMESPACE(scalar_compress_d1) +#define scalar_compress_d4 MLKEM_NAMESPACE(scalar_compress_d4) +#define scalar_compress_d5 MLKEM_NAMESPACE(scalar_compress_d5) +#define scalar_compress_d10 MLKEM_NAMESPACE(scalar_compress_d10) +#define scalar_compress_d11 MLKEM_NAMESPACE(scalar_compress_d11) +#define scalar_decompress_d4 MLKEM_NAMESPACE(scalar_decompress_d4) +#define scalar_decompress_d5 MLKEM_NAMESPACE(scalar_decompress_d5) +#define scalar_decompress_d10 MLKEM_NAMESPACE(scalar_decompress_d10) +#define scalar_decompress_d11 MLKEM_NAMESPACE(scalar_decompress_d11) +#define scalar_signed_to_unsigned_q MLKEM_NAMESPACE(scalar_signed_to_unsigned_q) +/* End of static namespacing */ + +/************************************************************ + * Name: scalar_compress_d1 + * + * Description: Computes round(u * 2 / q) + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 1. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d1(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 2) + ensures(return_value == (((uint32_t)u * 2 + MLKEM_Q / 2) / MLKEM_Q) % 2) ) +{ + uint32_t d0 = u << 1; + d0 *= 645083; + d0 += 1u << 30; + d0 >>= 31; + return d0; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_compress_d4 + * + * Description: Computes round(u * 16 / q) % 16 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. 
+ ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d4(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 16) + ensures(return_value == (((uint32_t)u * 16 + MLKEM_Q / 2) / MLKEM_Q) % 16)) +{ + uint32_t d0 = (uint32_t)u * 1290160; /* 16 * round(2^28 / MLKEM_Q) */ + return (d0 + (1u << 27)) >> 28; /* round(d0/2^28) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d4 + * + * Description: Computes round(u * q / 16) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 4. + * + * Arguments: - u: Unsigned canonical modulus modulo 16 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d4(uint32_t u) +__contract__( + requires(0 <= u && u < 16) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 8) / 16; } + +/************************************************************ + * Name: scalar_compress_d5 + * + * Description: Computes round(u * 32 / q) % 32 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d5(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < 32) + ensures(return_value == (((uint32_t)u * 32 + MLKEM_Q / 2) / MLKEM_Q) % 32) ) +{ + uint32_t d0 = (uint32_t)u * 1290176; /* 2^5 * round(2^27 / MLKEM_Q) */ + return (d0 + (1u << 26)) >> 27; /* round(d0/2^27) */ +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d5 + * + * Description: Computes round(u * q / 32) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 5. + * + * Arguments: - u: Unsigned canonical modulus modulo 32 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d5(uint32_t u) +__contract__( + requires(0 <= u && u < 32) + ensures(return_value <= MLKEM_Q - 1) +) { return ((u * MLKEM_Q) + 16) / 32; } + +/************************************************************ + * Name: scalar_compress_d10 + * + * Description: Computes round(u * 2**10 / q) % 2**10 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d10(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 10)) + ensures(return_value == (((uint32_t)u * (1u << 10) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 10))) +{ + uint64_t d0 = (uint64_t)u * 2642263040; /* 2^10 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x3FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d10 + * + * Description: Computes round(u * q / 1024) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 10. + * + * Arguments: - u: Unsigned canonical modulus modulo 1024 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d10(uint32_t u) +__contract__( + requires(0 <= u && u < 1024) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 512) / 1024; } + +/************************************************************ + * Name: scalar_compress_d11 + * + * Description: Computes round(u * 2**11 / q) % 2**11 + * + * Implements Compress_d from FIPS203, Eq (4.7), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo q + * to be compressed. + ************************************************************/ +/* + * The multiplication in this routine will exceed UINT32_MAX + * and wrap around for large values of u. This is expected and required. 
+ */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif +static INLINE uint32_t scalar_compress_d11(uint16_t u) +__contract__( + requires(u <= MLKEM_Q - 1) + ensures(return_value < (1u << 11)) + ensures(return_value == (((uint32_t)u * (1u << 11) + MLKEM_Q / 2) / MLKEM_Q) % (1 << 11))) +{ + uint64_t d0 = (uint64_t)u * 5284526080; /* 2^11 * round(2^33 / MLKEM_Q) */ + d0 = (d0 + ((uint64_t)1u << 32)) >> 33; + return (d0 & 0x7FF); +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************************ + * Name: scalar_decompress_d11 + * + * Description: Computes round(u * q / 2048) + * + * Implements Decompress_d from FIPS203, Eq (4.8), + * for d = 11. + * + * Arguments: - u: Unsigned canonical modulus modulo 2048 + * to be decompressed. + ************************************************************/ +static INLINE uint16_t scalar_decompress_d11(uint32_t u) +__contract__( + requires(0 <= u && u < 2048) + ensures(return_value <= (MLKEM_Q - 1)) +) { return ((u * MLKEM_Q) + 1024) / 2048; } + +/************************************************************ + * Name: scalar_signed_to_unsigned_q + * + * Description: converts signed polynomial coefficient + * from signed (-3328 .. 3328) form to + * unsigned form (0 .. 3328). 
+ * + * Note: Cryptographic constant time implementation + * + * Examples: 0 -> 0 + * 1 -> 1 + * 3328 -> 3328 + * -1 -> 3328 + * -2 -> 3327 + * -3328 -> 1 + * + * Arguments: c: signed coefficient to be converted + ************************************************************/ +static INLINE uint16_t scalar_signed_to_unsigned_q(int16_t c) +__contract__( + requires(c >= -(MLKEM_Q - 1) && c <= (MLKEM_Q - 1)) + ensures(return_value >= 0 && return_value <= (MLKEM_Q - 1)) + ensures(return_value == (int32_t)c + (((int32_t)c < 0) * MLKEM_Q))) +{ + /* Add Q if c is negative, but in constant time */ + c = ct_sel_int16(c + MLKEM_Q, c, ct_cmask_neg_i16(c)); + + cassert(c >= 0, "scalar_signed_to_unsigned_q result lower bound"); + cassert(c < MLKEM_Q, "scalar_signed_to_unsigned_q result upper bound"); + + /* and therefore cast to uint16_t is safe. */ + return (uint16_t)c; +} + +#define poly_compress_du MLKEM_NAMESPACE(poly_compress_du) +/************************************************* + * Name: poly_compress_du + * + * Description: Compression (du bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_du(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DU], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(memory_slice(r, MLKEM_POLYCOMPRESSEDBYTES_DU)) +); + +#define poly_decompress_du MLKEM_NAMESPACE(poly_decompress_du) +/************************************************* + * Name: poly_decompress_du + * + * Description: De-serialization and subsequent decompression (du bits) of a + *polynomial; approximate inverse of poly_compress_du + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_du(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_compress_dv MLKEM_NAMESPACE(poly_compress_dv) +/************************************************* + * Name: poly_compress_dv + * + * Description: Compression (dv bits) and subsequent serialization of a + *polynomial + * + * Arguments: - uint8_t *r: pointer to output byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV) + * - const poly *a: pointer to input polynomial + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_compress_dv(uint8_t r[MLKEM_POLYCOMPRESSEDBYTES_DV], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + +#define poly_decompress_dv MLKEM_NAMESPACE(poly_decompress_dv) +/************************************************* + * Name: poly_decompress_dv + * + * Description: De-serialization and subsequent decompression (dv bits) of a + *polynomial; approximate inverse of poly_compress + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYCOMPRESSEDBYTES_DV + *bytes) + * + * Upon return, the coefficients of the output polynomial are unsigned-canonical + * (non-negative and smaller than MLKEM_Q). + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_decompress_dv(poly *r, const uint8_t a[MLKEM_POLYCOMPRESSEDBYTES_DV]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYCOMPRESSEDBYTES_DV)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tobytes MLKEM_NAMESPACE(poly_tobytes) +/************************************************* + * Name: poly_tobytes + * + * Description: Serialization of a polynomial. + * Signed coefficients are converted to + * unsigned form before serialization. 
+ * + * Arguments: INPUT: + * - a: const pointer to input polynomial, + * with each coefficient in the range [0,1,..,Q-1] + * OUTPUT + * - r: pointer to output byte array + * (of MLKEM_POLYBYTES bytes) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tobytes(uint8_t r[MLKEM_POLYBYTES], const poly *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYBYTES)) + requires(memory_no_alias(a, sizeof(poly))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(r)) +); + + +#define poly_frombytes MLKEM_NAMESPACE(poly_frombytes) +/************************************************* + * Name: poly_frombytes + * + * Description: De-serialization of a polynomial. + * + * Arguments: INPUT + * - a: pointer to input byte array + * (of MLKEM_POLYBYTES bytes) + * OUTPUT + * - r: pointer to output polynomial, with + * each coefficient unsigned and in the range + * 0 .. 4095 + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frombytes(poly *r, const uint8_t a[MLKEM_POLYBYTES]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) +); + + +#define poly_frommsg MLKEM_NAMESPACE(poly_frommsg) +/************************************************* + * Name: poly_frommsg + * + * Description: Convert 32-byte message to polynomial + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *msg: pointer to input message + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_frommsg(poly *r, const uint8_t msg[MLKEM_INDCPA_MSGBYTES]) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + assigns(object_whole(r)) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_tomsg 
MLKEM_NAMESPACE(poly_tomsg) +/************************************************* + * Name: poly_tomsg + * + * Description: Convert polynomial to 32-byte message + * + * Arguments: - uint8_t *msg: pointer to output message + * - const poly *r: pointer to input polynomial + * Coefficients must be unsigned canonical + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomsg(uint8_t msg[MLKEM_INDCPA_MSGBYTES], const poly *r) +__contract__( + requires(memory_no_alias(msg, MLKEM_INDCPA_MSGBYTES)) + requires(memory_no_alias(r, sizeof(poly))) + requires(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) + assigns(object_whole(msg)) +); + +#define poly_getnoise_eta1_4x MLKEM_NAMESPACE(poly_getnoise_eta1_4x) +/************************************************* + * Name: poly_getnoise_eta1_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial distribution + * with parameter MLKEM_ETA1. + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], uint8_t nonce0, + uint8_t nonce1, uint8_t nonce2, uint8_t nonce3) +/* Depending on MLKEM_K, the pointers passed to this function belong + to the same objects, so we cannot use memory_no_alias for r0-r3. + + NOTE: Somehow it is important to use memory_no_alias() first in the + conjunctions defining each case. 
+*/ +#if MLKEM_K == 2 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case A: r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 4 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case B: r0, r1, r2, r3 consecutive */ + (memory_no_alias(r0, 4 * sizeof(poly)) && r1 == r0 + 1 && r2 == r0 + 2 && r3 == r0 + 3)) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#elif MLKEM_K == 3 +__contract__( + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + requires( /* Case C: r0, r1, r2 consecutive */ + (memory_no_alias(r0, 3 * sizeof(poly)) && memory_no_alias(r3, 1 * sizeof(poly)) && + r1 == r0 + 1 && r2 == r0 + 2 && !same_object(r3, r0))) + assigns(memory_slice(r0, sizeof(poly))) + assigns(memory_slice(r1, sizeof(poly))) + assigns(memory_slice(r2, sizeof(poly))) + assigns(memory_slice(r3, sizeof(poly))) + ensures( + array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) 
+ && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1)); +); +#endif /* MLKEM_K */ + +#if MLKEM_ETA1 == MLKEM_ETA2 +/* + * We only require poly_getnoise_eta2_4x for ml-kem-768 and ml-kem-1024 + * where MLKEM_ETA2 = MLKEM_ETA1 = 2. + * For ml-kem-512, poly_getnoise_eta1122_4x is used instead. + */ +#define poly_getnoise_eta2_4x poly_getnoise_eta1_4x +#endif /* MLKEM_ETA1 == MLKEM_ETA2 */ + +#if MLKEM_K == 2 || MLKEM_K == 4 +#define poly_getnoise_eta2 MLKEM_NAMESPACE(poly_getnoise_eta2) +/************************************************* + * Name: poly_getnoise_eta2 + * + * Description: Sample a polynomial deterministically from a seed and a nonce, + * with output polynomial close to centered binomial distribution + * with parameter MLKEM_ETA2 + * + * Arguments: - poly *r: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce: one-byte input nonce + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta2(poly *r, const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_ETA2 + 1)) +); +#endif /* MLKEM_K == 2 || MLKEM_K == 4 */ + +#if MLKEM_K == 2 +#define poly_getnoise_eta1122_4x MLKEM_NAMESPACE(poly_getnoise_eta1122_4x) +/************************************************* + * Name: poly_getnoise_eta1122_4x + * + * Description: Batch sample four polynomials deterministically from a seed + * and nonces, with output polynomials close to centered binomial + * distribution with parameter MLKEM_ETA1 and MLKEM_ETA2 + * + * Arguments: - poly *r{0,1,2,3}: pointer to output polynomial + * - const uint8_t *seed: pointer to input seed + * (of length MLKEM_SYMBYTES bytes) + * - uint8_t nonce{0,1,2,3}: one-byte input nonce +
**************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_getnoise_eta1122_4x(poly *r0, poly *r1, poly *r2, poly *r3, + const uint8_t seed[MLKEM_SYMBYTES], + uint8_t nonce0, uint8_t nonce1, uint8_t nonce2, + uint8_t nonce3) +__contract__( + requires( /* r0, r1 consecutive, r2, r3 consecutive */ + (memory_no_alias(r0, 2 * sizeof(poly)) && memory_no_alias(r2, 2 * sizeof(poly)) && + r1 == r0 + 1 && r3 == r2 + 1 && !same_object(r0, r2))) + requires(memory_no_alias(seed, MLKEM_SYMBYTES)) + assigns(object_whole(r0), object_whole(r1), object_whole(r2), object_whole(r3)) + ensures(array_abs_bound(r0->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r1->coeffs,0, MLKEM_N, MLKEM_ETA1 + 1) + && array_abs_bound(r2->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1) + && array_abs_bound(r3->coeffs,0, MLKEM_N, MLKEM_ETA2 + 1)); +); +#endif /* MLKEM_K == 2 */ + +#define poly_basemul_montgomery_cached \ + MLKEM_NAMESPACE(poly_basemul_montgomery_cached) +/************************************************* + * Name: poly_basemul_montgomery_cached + * + * Description: Multiplication of two polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise in [0, UINT12_LIMIT). + * + * The result is coefficient-wise bound by 2 q in absolute + * value. + * + * Arguments: - poly *r: pointer to output polynomial + * - const poly *a: pointer to first input polynomial + * - const poly *b: pointer to second input polynomial + * - const poly_mulcache *b_cache: pointer to mulcache + * for second input polynomial. Can be computed + * via poly_mulcache_compute().
+ **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_basemul_montgomery_cached(poly *r, const poly *a, const poly *b, + const poly_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(memory_no_alias(b_cache, sizeof(poly_mulcache))) + requires(array_bound(a->coeffs, 0, MLKEM_N, 0, UINT12_LIMIT)) + assigns(object_whole(r)) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, 2 * MLKEM_Q)) +); + +#define poly_tomont MLKEM_NAMESPACE(poly_tomont) +/************************************************* + * Name: poly_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void poly_tomont(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_Q)) +); + +#define poly_mulcache_compute MLKEM_NAMESPACE(poly_mulcache_compute) +/************************************************************ + * Name: poly_mulcache_compute + * + * Description: Computes the mulcache for a polynomial in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_mulcache_compute(poly_mulcache *x, const poly *a) +__contract__( + requires(memory_no_alias(x, sizeof(poly_mulcache))) + requires(memory_no_alias(a, sizeof(poly))) + assigns(object_whole(x)) +); + +#define poly_reduce MLKEM_NAMESPACE(poly_reduce) +/************************************************* + * Name: poly_reduce + * + * Description: Converts polynomial to _unsigned canonical_ representatives. + * + * The input coefficients can be arbitrary integers in int16_t. + * The output coefficients are in [0,1,...,MLKEM_Q-1]. + * + * Arguments: - poly *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of poly_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_reduce(poly *r) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + assigns(memory_slice(r, sizeof(poly))) + ensures(array_bound(r->coeffs, 0, MLKEM_N, 0, MLKEM_Q)) +); + +#define poly_add MLKEM_NAMESPACE(poly_add) +/************************************************************ + * Name: poly_add + * + * Description: Adds two polynomials in place + * + * Arguments: - r: Pointer to input-output polynomial to be added to. + * - b: Pointer to input polynomial that should be added + * to r. Must be disjoint from r. + * + * The coefficients of r and b must be so that the addition does + * not overflow. 
Otherwise, the behaviour of this function is undefined. + * + ************************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_add. + * We specialize to the accumulator form to avoid reasoning about aliasing. + */ +MLKEM_NATIVE_INTERNAL_API +void poly_add(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] + b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] + b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] + b->coeffs[k])) + assigns(memory_slice(r, sizeof(poly))) +); + +#define poly_sub MLKEM_NAMESPACE(poly_sub) +/************************************************* + * Name: poly_sub + * + * Description: Subtract two polynomials; no modular reduction is performed + * + * Arguments: - poly *r: Pointer to input-output polynomial to be + * subtracted from. + * - const poly *b: Pointer to second input polynomial + **************************************************/ +/* + * NOTE: The reference implementation uses a 3-argument poly_sub. + * We specialize to the accumulator form to avoid reasoning about aliasing.
+ */ +MLKEM_NATIVE_INTERNAL_API +void poly_sub(poly *r, const poly *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(b, sizeof(poly))) + requires(forall(k0, 0, MLKEM_N, (int32_t) r->coeffs[k0] - b->coeffs[k0] <= INT16_MAX)) + requires(forall(k1, 0, MLKEM_N, (int32_t) r->coeffs[k1] - b->coeffs[k1] >= INT16_MIN)) + ensures(forall(k, 0, MLKEM_N, r->coeffs[k] == old(*r).coeffs[k] - b->coeffs[k])) + assigns(object_whole(r)) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.c new file mode 100644 index 0000000000..7d20167731 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.c @@ -0,0 +1,172 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "polyvec.h" +#include +#include "arith_backend.h" +#include "ntt.h" +#include "poly.h" + +#include "debug/debug.h" + +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +{ + unsigned i; + POLYVEC_UBOUND(a, MLKEM_Q); + + for (i = 0; i < MLKEM_K; i++) + { + poly_compress_du(r + i * MLKEM_POLYCOMPRESSEDBYTES_DU, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_decompress_du(&r->vec[i], a + i * MLKEM_POLYCOMPRESSEDBYTES_DU); + } + + POLYVEC_UBOUND(r, MLKEM_Q); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tobytes(r + i * MLKEM_POLYBYTES, &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_frombytes(&r->vec[i], a + i * MLKEM_POLYBYTES); + } +} + 
+MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_ntt(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_invntt_tomont(&r->vec[i]); + } +} + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED) +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + POLYVEC_BOUND(b_cache, MLKEM_Q); + + poly_basemul_montgomery_cached(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_cached(&t, &a->vec[i], &b->vec[i], + &b_cache->vec[i]); + poly_add(r, &t); + /* abs bounds: < (i+1) * 3/2 * q */ + } + + /* + * Those bounds are true for the C implementation, but not needed + * in the higher level bounds reasoning. It is thus best to omit + * them from the spec to not unnecessarily constraint native implementations. + */ + cassert(array_abs_bound(r->coeffs, 0, MLKEM_N, MLKEM_K * 2 * MLKEM_Q), + "polyvec_basemul_acc_montgomery_cached output bounds"); + /* TODO: Integrate CBMC assertion into POLY_BOUND if CBMC is set */ + POLY_BOUND(r, MLKEM_K * 2 * MLKEM_Q); +} +#else /* !MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + POLYVEC_BOUND(a, 4096); + POLYVEC_BOUND(b, NTT_BOUND); + /* Omitting POLYVEC_BOUND(b_cache, MLKEM_Q) since native implementations may + * decide not to use a mulcache. Note that the C backend implementation + * of poly_basemul_montgomery_cached() does still include the check. 
*/ + polyvec_basemul_acc_montgomery_cached_native(r, a, b, b_cache); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED */ + +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +{ + polyvec_mulcache b_cache; + polyvec_mulcache_compute(&b_cache, b); + polyvec_basemul_acc_montgomery_cached(r, a, b, &b_cache); +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_mulcache_compute(&x->vec[i], &a->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_reduce(&r->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_add(&r->vec[i], &b->vec[i]); + } +} + +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +{ + unsigned i; + for (i = 0; i < MLKEM_K; i++) + { + poly_tomont(&r->vec[i]); + } +} diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.h new file mode 100644 index 0000000000..1387241502 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/polyvec.h @@ -0,0 +1,332 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef POLYVEC_H +#define POLYVEC_H + +#include +#include "common.h" +#include "poly.h" + +#define polyvec MLKEM_NAMESPACE(polyvec) +typedef struct +{ + poly vec[MLKEM_K]; +} ALIGN polyvec; + +#define polyvec_mulcache MLKEM_NAMESPACE(polyvec_mulcache) +typedef struct +{ + poly_mulcache vec[MLKEM_K]; +} polyvec_mulcache; + +#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du) +/************************************************* + * Name: polyvec_compress_du + * + * Description: Compress and serialize vector of polynomials + * + * 
Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES_DU) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_compress_du(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES_DU], + const polyvec *a) +__contract__( + requires(memory_no_alias(r, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_decompress_du MLKEM_NAMESPACE(polyvec_decompress_du) +/************************************************* + * Name: polyvec_decompress_du + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress_du + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. 
+ * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES_DU) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_decompress_du(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES_DU]) +__contract__( + requires(memory_no_alias(a, MLKEM_POLYVECCOMPRESSEDBYTES_DU)) + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_tobytes MLKEM_NAMESPACE(polyvec_tobytes) +/************************************************* + * Name: polyvec_tobytes + * + * Description: Serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECBYTES) + * - const polyvec *a: pointer to input vector of polynomials + * Each polynomial must have coefficients in [0,..,q-1]. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) +__contract__( + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(r, MLKEM_POLYVECBYTES)) + requires(forall(k0, 0, MLKEM_K, + array_bound(a->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) + assigns(object_whole(r)) +); + +#define polyvec_frombytes MLKEM_NAMESPACE(polyvec_frombytes) +/************************************************* + * Name: polyvec_frombytes + * + * Description: De-serialize vector of polynomials; + * inverse of polyvec_tobytes + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized in [0..4095]. + * - const uint8_t *a: pointer to input byte array
+ * (of length MLKEM_POLYVECBYTES) + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_frombytes(polyvec *r, const uint8_t a[MLKEM_POLYVECBYTES]) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(a, MLKEM_POLYVECBYTES)) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) +); + +#define polyvec_ntt MLKEM_NAMESPACE(polyvec_ntt) +/************************************************* + * Name: polyvec_ntt + * + * Description: Apply forward NTT to all elements of a vector of polynomials. + * + * The input is assumed to be in normal order and + * coefficient-wise bound by MLKEM_Q in absolute value. + * + * The output polynomial is in bitreversed order, and + * coefficient-wise bound by NTT_BOUND in absolute value. + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_ntt(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, NTT_BOUND))) +); + +#define polyvec_invntt_tomont MLKEM_NAMESPACE(polyvec_invntt_tomont) +/************************************************* + * Name: polyvec_invntt_tomont + * + * Description: Apply inverse NTT to all elements of a vector of polynomials + * and multiply by Montgomery factor 2^16 + * + * The input is assumed to be in bitreversed order, and can + * have arbitrary coefficients in int16_t. + * + * The output polynomial is in normal order, and + * coefficient-wise bound by INVNTT_BOUND in absolute value.
+ * + * + * Arguments: - polyvec *r: pointer to in/output vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_invntt_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, INVNTT_BOUND))) +); + +#define polyvec_basemul_acc_montgomery \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery) +/************************************************* + * Name: polyvec_basemul_acc_montgomery + * + * Description: Multiply elements of a and b in NTT domain, accumulate into r, + * and multiply by 2^-16. + * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input vector of polynomials + * - const polyvec *b: pointer to second input vector of polynomials + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + + +#define polyvec_basemul_acc_montgomery_cached \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached) +/************************************************* + * Name: polyvec_basemul_acc_montgomery_cached + * + * Description: Scalar product of two vectors of polynomials in NTT domain, + * using mulcache for second operand. + * + * Bounds: + * - a is assumed to be coefficient-wise < 4096 in absolute value. + * - No bounds guarantees for the coefficients in the result. 
+ * + * Arguments: - poly *r: pointer to output polynomial + * - const polyvec *a: pointer to first input polynomial vector + * - const polyvec *b: pointer to second input polynomial vector + * - const polyvec_mulcache *b_cache: pointer to mulcache + * for second input polynomial vector. Can be computed + * via polyvec_mulcache_compute(). + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_basemul_acc_montgomery_cached(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +__contract__( + requires(memory_no_alias(r, sizeof(poly))) + requires(memory_no_alias(a, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(memory_no_alias(b_cache, sizeof(polyvec_mulcache))) + requires(forall(k1, 0, MLKEM_K, + array_bound(a->vec[k1].coeffs, 0, MLKEM_N, 0, UINT12_LIMIT))) + assigns(memory_slice(r, sizeof(poly))) +); + +#define polyvec_mulcache_compute MLKEM_NAMESPACE(polyvec_mulcache_compute) +/************************************************************ + * Name: polyvec_mulcache_compute + * + * Description: Computes the mulcache for a vector of polynomials in NTT domain + * + * The mulcache of a degree-2 polynomial b := b0 + b1*X + * in Fq[X]/(X^2-zeta) is the value b1*zeta, needed when + * computing products of b in Fq[X]/(X^2-zeta). + * + * The mulcache of a polynomial in NTT domain -- which is + * a 128-tuple of degree-2 polynomials in Fq[X]/(X^2-zeta), + * for varying zeta, is the 128-tuple of mulcaches of those + * polynomials. + * + * The mulcache of a vector of polynomials is the vector + * of mulcaches of its entries. 
+ * + * Arguments: - x: Pointer to mulcache to be populated + * - a: Pointer to input polynomial vector + ************************************************************/ +/* + * NOTE: The default C implementation of this function populates + * the mulcache with values in (-q,q), but this is not needed for the + * higher level safety proofs, and thus not part of the spec. + */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_mulcache_compute(polyvec_mulcache *x, const polyvec *a) +__contract__( + requires(memory_no_alias(x, sizeof(polyvec_mulcache))) + requires(memory_no_alias(a, sizeof(polyvec))) + assigns(object_whole(x)) +); + +#define polyvec_reduce MLKEM_NAMESPACE(polyvec_reduce) +/************************************************* + * Name: polyvec_reduce + * + * Description: Applies Barrett reduction to each coefficient + * of each element of a vector of polynomials; + * for details of the Barrett reduction see comments in reduce.c + * + * Arguments: - polyvec *r: pointer to input/output polynomial + **************************************************/ +/* + * NOTE: The semantics of polyvec_reduce() is different in + * the reference implementation, which requires + * signed canonical output data. Unsigned canonical + * outputs are better suited to the only remaining + * use of poly_reduce() in the context of (de)serialization. 
+ */ +MLKEM_NATIVE_INTERNAL_API +void polyvec_reduce(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(k0, 0, MLKEM_K, + array_bound(r->vec[k0].coeffs, 0, MLKEM_N, 0, MLKEM_Q))) +); + +#define polyvec_add MLKEM_NAMESPACE(polyvec_add) +/************************************************* + * Name: polyvec_add + * + * Description: Add vectors of polynomials + * + * Arguments: - polyvec *r: pointer to input-output vector of polynomials to be + * added to + * - const polyvec *b: pointer to second input vector of polynomials + * + * The coefficients of r and b must be so that the addition does + * not overflow. Otherwise, the behaviour of this function is undefined. + * + * The coefficients returned in *r are in int16_t which is sufficient + * to prove type-safety of calling units. Therefore, no stronger + * ensures clause is required on this function. + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_add(polyvec *r, const polyvec *b) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + requires(memory_no_alias(b, sizeof(polyvec))) + requires(forall(j0, 0, MLKEM_K, + forall(k0, 0, MLKEM_N, + (int32_t)r->vec[j0].coeffs[k0] + b->vec[j0].coeffs[k0] <= INT16_MAX))) + requires(forall(j1, 0, MLKEM_K, + forall(k1, 0, MLKEM_N, + (int32_t)r->vec[j1].coeffs[k1] + b->vec[j1].coeffs[k1] >= INT16_MIN))) + assigns(object_whole(r)) +); + +#define polyvec_tomont MLKEM_NAMESPACE(polyvec_tomont) +/************************************************* + * Name: polyvec_tomont + * + * Description: Inplace conversion of all coefficients of a polynomial + * vector from normal domain to Montgomery domain + * + * Bounds: Output < q in absolute value. 
+ * + **************************************************/ +MLKEM_NATIVE_INTERNAL_API +void polyvec_tomont(polyvec *r) +__contract__( + requires(memory_no_alias(r, sizeof(polyvec))) + assigns(memory_slice(r, sizeof(polyvec))) + assigns(object_whole(r)) + ensures(forall(j, 0, MLKEM_K, + array_abs_bound(r->vec[j].coeffs, 0, MLKEM_N, MLKEM_Q))) +); + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/reduce.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/reduce.h new file mode 100644 index 0000000000..1f502167eb --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/reduce.h @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REDUCE_H +#define REDUCE_H + +#include +#include "cbmc.h" +#include "common.h" +#include "debug/debug.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define cast_uint16_to_int16 MLKEM_NAMESPACE(cast_uint16_to_int16) +#define montgomery_reduce_generic MLKEM_NAMESPACE(montgomery_reduce_generic) +#define montgomery_reduce MLKEM_NAMESPACE(montgomery_reduce) +#define fqmul MLKEM_NAMESPACE(fqmul) +#define barrett_reduce MLKEM_NAMESPACE(barrett_reduce) +/* End of static namespacing */ + +#define HALF_Q ((MLKEM_Q + 1) / 2) /* 1665 */ + +/************************************************* + * Name: cast_uint16_to_int16 + * + * Description: Cast uint16 value to int16 + * + * Returns: + * input x in 0 .. 32767: returns value unchanged + * input x in 32768 .. 
65535: returns (x - 65536) + **************************************************/ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif +ALWAYS_INLINE +static INLINE int16_t cast_uint16_to_int16(uint16_t x) +{ + /* + * PORTABILITY: This relies on uint16_t -> int16_t + * being implemented as the inverse of int16_t -> uint16_t, + * which is implementation-defined (C99 6.3.1.3 (3)) + * CBMC (correctly) fails to prove this conversion is OK, + * so we have to suppress that check here + */ + return (int16_t)x; +} +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: montgomery_reduce_generic + * + * Description: Generic Montgomery reduction; given a 32-bit integer a, computes + * 16-bit integer congruent to a * R^-1 mod q, where R=2^16 + * + * Arguments: - int32_t a: input integer to be reduced + * + * Returns: integer congruent to a * R^-1 modulo q, with absolute value + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1)/2 + * + **************************************************/ +ALWAYS_INLINE +static INLINE int16_t montgomery_reduce_generic(int32_t a) +{ + /* QINV == -3327 converted to uint16_t == -3327 + 65536 == 62209 */ + const uint32_t QINV = 62209; /* q^-1 mod 2^16 */ + + /* Compute a*q^{-1} mod 2^16 in unsigned representatives */ + const uint16_t a_reduced = a & UINT16_MAX; + const uint16_t a_inverted = (a_reduced * QINV) & UINT16_MAX; + + /* Lift to signed canonical representative mod 2^16. */ + const int16_t t = cast_uint16_to_int16(a_inverted); + + int32_t r = a - ((int32_t)t * MLKEM_Q); + /* Bounds: |r| <= |a| + 2^15 * MLKEM_Q */ + + /* + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. 
(C99 6.5.7 (5)) + */ + r = r >> 16; + /* Bounds: |r >> 16| <= ceil(|r| / 2^16) + * <= ceil(|a| / 2^16 + MLKEM_Q / 2) + * <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * + * (Note that |a >> n| = ceil(|a| / 2^n) for negative a) + */ + + return (int16_t)r; +} + +/************************************************* + * Name: montgomery_reduce + * + * Description: Montgomery reduction + * + * Arguments: - int32_t a: input integer to be reduced + * Must be smaller than 2 * 2^12 * 2^15 in absolute value. + * + * Returns: integer congruent to a * R^-1 modulo q, + * smaller than 2 * q in absolute value. + **************************************************/ +static INLINE int16_t montgomery_reduce(int32_t a) +__contract__( + requires(a > -(2 * 4096 * 32768)) + requires(a < (2 * 4096 * 32768)) + ensures(return_value > -2 * MLKEM_Q && return_value < 2 * MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(a, 2 * UINT12_LIMIT * 32768, "montgomery_reduce input"); + + res = montgomery_reduce_generic(a); + /* Bounds: + * |res| <= ceil(|a| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2 * UINT12_LIMIT * 32768 / 65536) + (MLKEM_Q + 1) / 2 + * <= UINT12_LIMIT + (MLKEM_Q + 1) / 2 + * < 2 * MLKEM_Q */ + + SCALAR_BOUND(res, 2 * MLKEM_Q, "montgomery_reduce output"); + return res; +} + +/************************************************* + * Name: fqmul + * + * Description: Montgomery multiplication modulo q=3329 + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <(q+1)/2) + * + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. 
+ * + **************************************************/ +static INLINE int16_t fqmul(int16_t a, int16_t b) +__contract__( + requires(b > -HALF_Q) + requires(b < HALF_Q) + ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) +) +{ + int16_t res; + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + res = montgomery_reduce((int32_t)a * (int32_t)b); + /* Bounds: + * |res| <= ceil(|a| * |b| / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil(2^15 * ((MLKEM_Q - 1)/2) / 2^16) + (MLKEM_Q + 1) / 2 + * <= ceil((MLKEM_Q - 1) / 4) + (MLKEM_Q + 1) / 2 + * < MLKEM_Q + */ + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; +} + +/************************************************* + * Name: barrett_reduce + * + * Description: Barrett reduction; given a 16-bit integer a, computes + * centered representative congruent to a mod q in + * {-(q-1)/2,...,(q-1)/2} + * + * Arguments: - int16_t a: input integer to be reduced + * + * Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. + **************************************************/ +static INLINE int16_t barrett_reduce(int16_t a) +__contract__( + ensures(return_value > -HALF_Q && return_value < HALF_Q) +) +{ + /* + * To divide by MLKEM_Q using Barrett multiplication, the "magic number" + * multiplier is round_to_nearest(2**26/MLKEM_Q) + */ + const int BPOWER = 26; + const int32_t barrett_multiplier = ((1 << BPOWER) + MLKEM_Q / 2) / MLKEM_Q; + + /* + * Compute round_to_nearest(a/MLKEM_Q) using the multiplier + * above and shift by BPOWER places. + * PORTABILITY: Right-shift on a signed integer is, strictly-speaking, + * implementation-defined for negative left argument. Here, + * we assume it's sign-preserving "arithmetic" shift right. (C99 6.5.7 (5)) + */ + const int32_t t = (barrett_multiplier * a + (1 << (BPOWER - 1))) >> BPOWER; + + /* + * t is in -10 .. 
+10, so we need 32-bit math to + * evaluate t * MLKEM_Q and the subsequent subtraction + */ + return (int16_t)(a - t * MLKEM_Q); +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.c new file mode 100644 index 0000000000..918986e9b2 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "rej_uniform.h" +#include "arith_backend.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define rej_uniform_scalar MLKEM_NAMESPACE(rej_uniform_scalar) +/* End of static namespacing */ + +/************************************************* + * Name: rej_uniform_scalar + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed. 
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ +static unsigned int rej_uniform_scalar(int16_t *r, unsigned int target, + unsigned int offset, const uint8_t *buf, + unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + + ctr = offset; + pos = 0; + /* pos + 3 cannot overflow due to the assumption buflen <= 4096 */ + while (ctr < target && pos + 3 <= buflen) + __loop__( + invariant(offset <= ctr && ctr <= target && pos <= buflen) + invariant(ctr > 0 ==> array_bound(r, 0, ctr, 0, MLKEM_Q))) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)) & 0xFFF; + pos += 3; + + if (val0 < MLKEM_Q) + { + r[ctr++] = val0; + } + if (ctr < target && val1 < MLKEM_Q) + { + r[ctr++] = val1; + } + } + return ctr; +} + +#if !defined(MLKEM_USE_NATIVE_REJ_UNIFORM) +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#else /* MLKEM_USE_NATIVE_REJ_UNIFORM */ + +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +{ + int ret; + + /* Sample from large buffer with full lane as much as possible. 
*/ + ret = rej_uniform_native(r + offset, target - offset, buf, buflen); + if (ret != -1) + return offset + (unsigned)ret; + + return rej_uniform_scalar(r, target, offset, buf, buflen); +} +#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.h new file mode 100644 index 0000000000..13db836bcc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/rej_uniform.h @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef REJ_UNIFORM_H +#define REJ_UNIFORM_H + +#include +#include +#include "cbmc.h" +#include "common.h" + +#define rej_uniform MLKEM_NAMESPACE(rej_uniform) +/************************************************* + * Name: rej_uniform + * + * Description: Run rejection sampling on uniform random bytes to generate + * uniform random integers mod q + * + * Arguments: - int16_t *r: pointer to output buffer + * - unsigned int target: requested number of 16-bit integers + * (uniform mod q). + * Must be <= 4096. + * - unsigned int offset: number of 16-bit integers that have + * already been sampled. + * Must be <= target. + * - const uint8_t *buf: pointer to input buffer + * (assumed to be uniform random bytes) + * - unsigned int buflen: length of input buffer in bytes + * Must be <= 4096. + * Must be a multiple of 3. + * + * Note: Strictly speaking, only a few values of buflen near UINT_MAX need + * excluding. The limit of 4096 is somewhat arbitrary but sufficient for all + * uses of this function. Similarly, the actual limit for target is UINT_MAX/2. + * + * Returns the new offset of sampled 16-bit integers, at most target, + * and at least the initial offset. + * If the new offset is strictly less than target, all of the input buffer + * is guaranteed to have been consumed. 
If it is equal to target, no information + * is provided on how many bytes of the input buffer have been consumed. + **************************************************/ + +/* + * NOTE: The signature differs from the Kyber reference implementation + * in that it adds the offset and always expects the base of the target + * buffer. This avoids shifting the buffer base in the caller, which appears + * tricky to reason about. + */ +MLKEM_NATIVE_INTERNAL_API +unsigned int rej_uniform(int16_t *r, unsigned int target, unsigned int offset, + const uint8_t *buf, unsigned int buflen) +__contract__( + requires(offset <= target && target <= 4096 && buflen <= 4096 && buflen % 3 == 0) + requires(memory_no_alias(r, sizeof(int16_t) * target)) + requires(memory_no_alias(buf, buflen)) + requires(offset > 0 ==> array_bound(r, 0, offset, 0, MLKEM_Q)) + assigns(memory_slice(r, sizeof(int16_t) * target)) + ensures(offset <= return_value && return_value <= target) + ensures(return_value > 0 ==> array_bound(r, 0, return_value, 0, MLKEM_Q)) +); +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/symmetric.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/symmetric.h new file mode 100644 index 0000000000..55ebbbd533 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/symmetric.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef SYMMETRIC_H +#define SYMMETRIC_H + +#include +#include +#include "cbmc.h" +#include "common.h" +#include "fips202.h" + +/* Macros denoting FIPS-203 specific Hash functions */ + +/* Hash function H, FIPS-203 4.1 (eq 4.4) */ +#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) + +/* Hash function G, FIPS-203 4.1 (eq 4.5) */ +#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) + +/* Hash function J, FIPS-203 4.1 (eq 4.4) */ +#define hash_j(OUT, IN, INBYTES) shake256(OUT, MLKEM_SYMBYTES, IN, INBYTES) + +/* PRF function, FIPS-203 4.1 (eq 4.3) + * 
Referring to (eq 4.3), `IN` is assumed to contain `s || b`. */ +#define prf_eta(ETA, OUT, IN) \ + shake256(OUT, (ETA) * MLKEM_N / 4, IN, MLKEM_SYMBYTES + 1) +#define prf_eta1(OUT, IN) prf_eta(MLKEM_ETA1, OUT, IN) +#define prf_eta2(OUT, IN) prf_eta(MLKEM_ETA2, OUT, IN) +#define prf_eta1_x4(OUT0, OUT1, OUT2, OUT3, IN0, IN1, IN2, IN3) \ + shake256x4(OUT0, OUT1, OUT2, OUT3, (MLKEM_ETA1 * MLKEM_N / 4), IN0, IN1, \ + IN2, IN3, MLKEM_SYMBYTES + 1) + +/* XOF function, FIPS-203 4.1 */ +#define xof_ctx shake128ctx +#define xof_x4_ctx shake128x4ctx +#define xof_absorb(CTX, IN, INBYTES) \ + shake128_absorb_once((CTX), (IN), (INBYTES)) +#define xof_squeezeblocks(BUF, NBLOCKS, CTX) \ + shake128_squeezeblocks((BUF), (NBLOCKS), (CTX)) +#define xof_release(CTX) shake128_release((CTX)) + +#define xof_x4_absorb(CTX, IN0, IN1, IN2, IN3, INBYTES) \ + shake128x4_absorb_once((CTX), (IN0), (IN1), (IN2), (IN3), (INBYTES)) +#define xof_x4_squeezeblocks(BUF0, BUF1, BUF2, BUF3, NBLOCKS, CTX) \ + shake128x4_squeezeblocks((BUF0), (BUF1), (BUF2), (BUF3), (NBLOCKS), (CTX)) +#define xof_x4_release(CTX) shake128x4_release((CTX)) + +#define XOF_RATE SHAKE128_RATE + +#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/sys.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/sys.h new file mode 100644 index 0000000000..a5820fa195 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/sys.h @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_NATIVE_SYS_H +#define MLKEM_NATIVE_SYS_H + +/* Check if we're running on an AArch64 little endian system. _M_ARM64 is set by + * MSVC. */ +#if defined(__AARCH64EL__) || defined(_M_ARM64) +#define SYS_AARCH64 +#endif + +/* Check if we're running on an AArch64 big endian system. 
*/ +#if defined(__AARCH64EB__) +#define SYS_AARCH64_EB +#endif + +#if defined(__x86_64__) +#define SYS_X86_64 +#if defined(__AVX2__) +#define SYS_X86_64_AVX2 +#endif +#endif /* __x86_64__ */ + +/* Try to find endianness, if not forced through CFLAGS already */ +#if !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) +#if defined(__BYTE_ORDER__) +#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#define SYS_LITTLE_ENDIAN +#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define SYS_BIG_ENDIAN +#else /* unrecognized __BYTE_ORDER__ value */ +#error "__BYTE_ORDER__ defined, but don't recognize value." +#endif /* __BYTE_ORDER__ value */ +#endif /* defined(__BYTE_ORDER__) */ +#endif /* !defined(SYS_LITTLE_ENDIAN) && !defined(SYS_BIG_ENDIAN) */ + +/* If FORCE_AARCH64 is set, assert that we're indeed on an AArch64 system. */ +#if defined(FORCE_AARCH64) && !defined(SYS_AARCH64) +#error "FORCE_AARCH64 is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_AARCH64_EB is set, assert that we're indeed on a big endian AArch64 + * system. */ +#if defined(FORCE_AARCH64_EB) && !defined(SYS_AARCH64_EB) +#error "FORCE_AARCH64_EB is set, but we don't seem to be on an AArch64 system." +#endif + +/* If FORCE_X86_64 is set, assert that we're indeed on an X86_64 system. */ +#if defined(FORCE_X86_64) && !defined(SYS_X86_64) +#error "FORCE_X86_64 is set, but we don't seem to be on an X86_64 system." +#endif + +/* + * C90 does not have the inline compiler directive yet. + * We don't use it in C90 builds. + * However, in that case the compiler warns about some inline functions in + * header files not being used in every compilation unit that includes that + * header. To work around it we silence that warning in that case using + * __attribute__((unused)). 
+ */ + +/* Do not use inline for C90 builds*/ +#if !defined(INLINE) +#if !defined(inline) +#if defined(_MSC_VER) +#define INLINE __inline +#define ALWAYS_INLINE __forceinline +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#else +#define INLINE __attribute__((unused)) +#define ALWAYS_INLINE +#endif + +#else +#define INLINE inline +#define ALWAYS_INLINE __attribute__((always_inline)) +#endif +#endif + +/* + * C90 does not have the restrict compiler directive yet. + * We don't use it in C90 builds. + */ +#if !defined(restrict) +#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L +#define RESTRICT restrict +#else +#define RESTRICT +#endif + +#else + +#define RESTRICT restrict +#endif + +#define DEFAULT_ALIGN 32 +#if defined(_WIN32) +#define ALIGN __declspec(align(DEFAULT_ALIGN)) +#define asm __asm +#else +#define asm __asm__ +#define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) +#endif + +#endif /* MLKEM_NATIVE_SYS_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.c new file mode 100644 index 0000000000..b7078fcc19 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.c @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#include "verify.h" + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) +/* + * Masking value used in constant-time functions from + * verify.h to block the compiler's range analysis and + * thereby reduce the risk of compiler-introduced branches. 
+ */ +volatile uint64_t ct_opt_blocker_u64 = 0; + +#else /* MLKEM_USE_ASM_VALUE_BARRIER */ + +#define empty_cu_verify MLKEM_NAMESPACE(empty_cu_verify) +int empty_cu_verify; + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.h new file mode 100644 index 0000000000..8c47155dcf --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/verify.h @@ -0,0 +1,317 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef VERIFY_H +#define VERIFY_H + +#include +#include +#include +#include "cbmc.h" +#include "common.h" + +/* Static namespacing + * This is to facilitate building multiple instances + * of mlkem-native (e.g. with varying security levels) + * within a single compilation unit. */ +#define value_barrier_u8 MLKEM_NAMESPACE(value_barrier_u8) +#define value_barrier_u32 MLKEM_NAMESPACE(value_barrier_u32) +#define value_barrier_i32 MLKEM_NAMESPACE(value_barrier_i32) +#define ct_cmask_neg_i16 MLKEM_NAMESPACE(ct_cmask_neg_i16) +#define ct_cmask_nonzero_u8 MLKEM_NAMESPACE(ct_cmask_nonzero_u8) +#define ct_cmask_nonzero_u16 MLKEM_NAMESPACE(ct_cmask_nonzero_u16) +#define ct_sel_uint8 MLKEM_NAMESPACE(ct_sel_uint8) +#define ct_sel_int16 MLKEM_NAMESPACE(ct_sel_int16) +#define ct_memcmp MLKEM_NAMESPACE(ct_memcmp) +#define ct_cmov_zero MLKEM_NAMESPACE(ct_cmov_zero) +/* End of static namespacing */ + +/* Constant-time comparisons and conditional operations + + We reduce the risk for compilation into variable-time code + through the use of 'value barriers'. + + Functionally, a value barrier is a no-op. To the compiler, however, + it constitutes an arbitrary modification of its input, and therefore + harden's value propagation and range analysis. + + We consider two approaches to implement a value barrier: + - An empty inline asm block which marks the target value as clobbered. 
+ - XOR'ing with the value of a volatile global that's set to 0; + for a discussion / implementation of this idea, see e.g. + * https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/H14H0wOlBgAJ + * https://lib.mceliece.org/libmceliece-20240513/inttypes/crypto_intN.h.html + + The first approach is cheap because it only prevents the compiler + from reasoning about the value of the variable past the barrier, + but does not directly generate additional instructions. + + The second approach generates redundant loads and XOR operations + and therefore comes at a higher runtime cost. However, it appears + more robust towards optimization, as compilers should never drop + a volatile load. + + We use the empty-ASM value barrier for GCC and clang, and fall + back to the global volatile barrier otherwise. + + The global value barrier can be forced by setting MLKEM_NO_ASM_VALUE_BARRIER. + +*/ + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(CBMC) && \ + !defined(MLKEM_NO_ASM_VALUE_BARRIER) +#define MLKEM_USE_ASM_VALUE_BARRIER +#endif + +#if !defined(MLKEM_USE_ASM_VALUE_BARRIER) + +/* + * Declaration of global volatile that the global value barrier + * is loading from and masking with. 
+ */ +#define ct_opt_blocker_u64 MLKEM_NAMESPACE(ct_opt_blocker_u64) +extern volatile uint64_t ct_opt_blocker_u64; + +/* Helper functions for obtaining masks of various sizes */ +static INLINE uint8_t get_optblocker_u8(void) +__contract__(ensures(return_value == 0)) { return (uint8_t)ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_u32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t get_optblocker_i32(void) +__contract__(ensures(return_value == 0)) { return ct_opt_blocker_u64; } + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u32()); } + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_i32()); } + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) { return (b ^ get_optblocker_u8()); } + +#else /* !MLKEM_USE_ASM_VALUE_BARRIER */ + +static INLINE uint32_t value_barrier_u32(uint32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE int32_t value_barrier_i32(int32_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +static INLINE uint8_t value_barrier_u8(uint8_t b) +__contract__(ensures(return_value == b)) +{ + asm("" : "+r"(b)); + return b; +} + +#endif /* MLKEM_USE_ASM_VALUE_BARRIER */ + +/* + * The ct_cmask_nonzero_xxx functions below make deliberate use of unsigned + * overflow, which is fully defined behaviour in C. It is thus safe to disable + * this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "unsigned-overflow" +#endif + +/************************************************* + * Name: ct_cmask_nonzero_u16 + * + * Description: Return 0 if input is zero, and -1 otherwise. 
+ * + * Arguments: uint16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_nonzero_u16(uint16_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFFFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 16; + return tmp; +} + +/************************************************* + * Name: ct_cmask_nonzero_u8 + * + * Description: Return 0 if input is zero, and -1 (i.e. 0xFF) otherwise. + * + * Arguments: uint8_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint8_t ct_cmask_nonzero_u8(uint8_t x) +__contract__(ensures(return_value == ((x == 0) ? 0 : 0xFF))) +{ + uint32_t tmp = value_barrier_u32(-((uint32_t)x)); + tmp >>= 24; + return tmp; +} + +/* Put unsigned overflow warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_cmask_neg_i16 function below makes deliberate use of + * signed to unsigned integer conversion, which is fully defined + * behaviour in C. It is thus safe to disable this warning. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_cmask_neg_i16 + * + * Description: Return 0 if input is non-negative, and -1 (i.e. 0xFFFF) otherwise. + * + * Arguments: int16_t x: Value to be converted into a mask + **************************************************/ +static INLINE uint16_t ct_cmask_neg_i16(int16_t x) +__contract__(ensures(return_value == ((x < 0) ? 0xFFFF : 0))) +{ + int32_t tmp = value_barrier_i32((int32_t)x); + tmp >>= 16; + return (int16_t)tmp; +} + +/* Put signed-to-unsigned conversion warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/* + * The ct_sel_int16 function below makes deliberate use of unsigned + * to signed integer conversion, which is implementation-defined + * behaviour. 
Here, we assume that uint16_t -> int16_t is inverse + * to int16_t -> uint16_t. + */ +#ifdef CBMC +#pragma CPROVER check push +#pragma CPROVER check disable "conversion" +#endif + +/************************************************* + * Name: ct_sel_int16 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: int16_t a: First alternative + * int16_t b: Second alternative + * uint16_t cond: Condition variable. + **************************************************/ +static INLINE int16_t ct_sel_int16(int16_t a, int16_t b, uint16_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + uint16_t au = a, bu = b; + uint16_t res = bu ^ (ct_cmask_nonzero_u16(cond) & (au ^ bu)); + return (int16_t)res; +} + +/* Put unsigned-to-signed warnings in CBMC back into scope */ +#ifdef CBMC +#pragma CPROVER check pop +#endif + +/************************************************* + * Name: ct_sel_uint8 + * + * Description: Functionally equivalent to cond ? a : b, + * but implemented with guards against + * compiler-introduced branches. + * + * Arguments: uint8_t a: First alternative + * uint8_t b: Second alternative + * uint8_t cond: Condition variable. + **************************************************/ +static INLINE uint8_t ct_sel_uint8(uint8_t a, uint8_t b, uint8_t cond) +__contract__(ensures(return_value == (cond ? a : b))) +{ + return b ^ (ct_cmask_nonzero_u8(cond) & (a ^ b)); +} + +/************************************************* + * Name: ct_memcmp + * + * Description: Compare two arrays for equality in constant time. 
+ * + * Arguments: const uint8_t *a: pointer to first byte array + * const uint8_t *b: pointer to second byte array + * size_t len: length of the byte arrays + * + * Returns 0 if the byte arrays are equal, a non-zero value otherwise + **************************************************/ +static INLINE uint8_t ct_memcmp(const uint8_t *a, const uint8_t *b, + const size_t len) +__contract__( + requires(memory_no_alias(a, len)) + requires(memory_no_alias(b, len)) + requires(len <= INT_MAX) + ensures((return_value == 0) == forall(i, 0, len, (a[i] == b[i])))) +{ + uint8_t r = 0, s = 0; + unsigned i; + + for (i = 0; i < len; i++) + __loop__( + invariant(i >= 0 && i <= len) + invariant((r == 0) == (forall(k, 0, i, (a[k] == b[k]))))) + { + r |= a[i] ^ b[i]; + /* s is useless, but prevents the loop from being aborted once r=0xff. */ + s ^= a[i] ^ b[i]; + } + + /* + * - Convert r into a mask; this may not be necessary, but is an additional + * safeguard + * against leaking information about a and b. + * - XOR twice with s, separated by a value barrier, to prevent the compiler + * from dropping the s computation in the loop. + */ + return (value_barrier_u8(ct_cmask_nonzero_u8(r) ^ s) ^ s); +} + +/************************************************* + * Name: ct_cmov_zero + * + * Description: Copy len bytes from x to r if b is zero; + * don't modify r if b is non-zero. + * assumes two's complement representation of negative integers. + * Runs in constant time. + * + * Arguments: uint8_t *r: pointer to output byte array + * const uint8_t *x: pointer to input byte array + * size_t len: Amount of bytes to be copied + * uint8_t b: Condition value. 
+ **************************************************/ +static INLINE void ct_cmov_zero(uint8_t *r, const uint8_t *x, size_t len, + uint8_t b) +__contract__( + requires(memory_no_alias(r, len)) + requires(memory_no_alias(x, len)) + assigns(memory_slice(r, len))) +{ + size_t i; + for (i = 0; i < len; i++) + __loop__(invariant(i <= len)) + { + r[i] = ct_sel_uint8(r[i], x[i], b); + } +} + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/README.md b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/README.md new file mode 100644 index 0000000000..2073425c3b --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/README.md @@ -0,0 +1,4 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +This directory contains the native x86_64 arithmetic backend for ML-KEM provided by the official [AVX2 +implementation](https://github.com/pq-crystals/kyber/tree/main/avx2) of the Kyber team. diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/default.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/default.h new file mode 100644 index 0000000000..592e8996dc --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/default.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? +#else +#define MLKEM_NATIVE_ARITH_PROFILE_H + +/* Identifier for this backend so that source and assembly files + * in the build can be appropriately guarded. */ +#define MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT + +#define MLKEM_NATIVE_ARITH_BACKEND_NAME X86_64_DEFAULT + +/* Filename of the C backend implementation. + * This is not inlined here because this header is included in assembly + * files as well. 
*/ +#define MLKEM_NATIVE_ARITH_BACKEND_IMPL "x86_64/src/default_impl.h" + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_H */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/align.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/align.h new file mode 100644 index 0000000000..42a02fe57c --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/align.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/align.h + */ + +#ifndef ALIGN_H +#define ALIGN_H + +#include +#include + +#define ALIGNED_UINT8(N) \ + union \ + { \ + uint8_t coeffs[N]; \ + __m256i vec[(N + 31) / 32]; \ + } + +#define ALIGNED_INT16(N) \ + union \ + { \ + int16_t coeffs[N]; \ + __m256i vec[(N + 15) / 16]; \ + } + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/arith_native_x86_64.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/arith_native_x86_64.h new file mode 100644 index 0000000000..ce13e7911f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/arith_native_x86_64.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ +#ifndef MLKEM_X86_64_NATIVE_H +#define MLKEM_X86_64_NATIVE_H + +#include "common.h" + +#include +#include +#include "polyvec.h" +#include "consts.h" + +#define REJ_UNIFORM_AVX_NBLOCKS 3 /* See MLKEM_GEN_MATRIX_NBLOCKS */ +#define REJ_UNIFORM_AVX_BUFLEN \ + (3 * 168) /* REJ_UNIFORM_AVX_NBLOCKS * SHAKE128_RATE */ + +#define rej_uniform_avx2 MLKEM_NAMESPACE(rej_uniform_avx2) +unsigned int rej_uniform_avx2(int16_t *r, const uint8_t *buf); + +#define rej_uniform_table MLKEM_NAMESPACE(rej_uniform_table) +extern const uint8_t rej_uniform_table[256][8]; + +#define ntt_avx2 MLKEM_NAMESPACE(ntt_avx2) +void ntt_avx2(__m256i *r, const 
__m256i *qdata); + +#define invntt_avx2 MLKEM_NAMESPACE(invntt_avx2) +void invntt_avx2(__m256i *r, const __m256i *qdata); + +#define nttpack_avx2 MLKEM_NAMESPACE(nttpack_avx2) +void nttpack_avx2(__m256i *r, const __m256i *qdata); + +#define nttunpack_avx2 MLKEM_NAMESPACE(nttunpack_avx2) +void nttunpack_avx2(__m256i *r, const __m256i *qdata); + +#define reduce_avx2 MLKEM_NAMESPACE(reduce_avx2) +void reduce_avx2(__m256i *r, const __m256i *qdata); + +#define basemul_avx2 MLKEM_NAMESPACE(basemul_avx2) +void basemul_avx2(__m256i *r, const __m256i *a, const __m256i *b, + const __m256i *qdata); + +#define polyvec_basemul_acc_montgomery_cached_avx2 \ + MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_avx2) +void polyvec_basemul_acc_montgomery_cached_avx2( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache); + +#define ntttobytes_avx2 MLKEM_NAMESPACE(ntttobytes_avx2) +void ntttobytes_avx2(uint8_t *r, const __m256i *a, const __m256i *qdata); + +#define nttfrombytes_avx2 MLKEM_NAMESPACE(nttfrombytes_avx2) +void nttfrombytes_avx2(__m256i *r, const uint8_t *a, const __m256i *qdata); + +#define tomont_avx2 MLKEM_NAMESPACE(tomont_avx2) +void tomont_avx2(__m256i *r, const __m256i *qdata); + +#endif /* MLKEM_X86_64_NATIVE_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/basemul.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.S similarity index 61% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/basemul.S rename to src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.S index 36990639b2..b97840e702 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/basemul.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.S @@ -1,12 +1,25 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// 
https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + #include "consts.h" +/* Polynomials to be multiplied are denoted a+bX (rsi arg) and c+dX (rdx arg) */ .macro schoolbook off -vmovdqa _16XQINV*2(%rcx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQINV*2(%rcx),%ymm0 vmovdqa (64*\off+ 0)*2(%rsi),%ymm1 # a0 vmovdqa (64*\off+16)*2(%rsi),%ymm2 # b0 vmovdqa (64*\off+32)*2(%rsi),%ymm3 # a1 vmovdqa (64*\off+48)*2(%rsi),%ymm4 # b1 +/* Prepare Montgomery twists */ vpmullw %ymm0,%ymm1,%ymm9 # a0.lo vpmullw %ymm0,%ymm2,%ymm10 # b0.lo vpmullw %ymm0,%ymm3,%ymm11 # a1.lo @@ -15,6 +28,7 @@ vpmullw %ymm0,%ymm4,%ymm12 # b1.lo vmovdqa (64*\off+ 0)*2(%rdx),%ymm5 # c0 vmovdqa (64*\off+16)*2(%rdx),%ymm6 # d0 +/* Compute high-parts of monomials in (a0+b0*X)*(c0+d0*X) */ vpmulhw %ymm5,%ymm1,%ymm13 # a0c0.hi vpmulhw %ymm6,%ymm1,%ymm1 # a0d0.hi vpmulhw %ymm5,%ymm2,%ymm14 # b0c0.hi @@ -23,6 +37,8 @@ vpmulhw %ymm6,%ymm2,%ymm2 # b0d0.hi vmovdqa (64*\off+32)*2(%rdx),%ymm7 # c1 vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1 +/* Compute high-parts of monomials in (a1+b1*X)*(c1+d1*X) */ +/* Don't yet accumulate nor reduce X^2 */ vpmulhw %ymm7,%ymm3,%ymm15 # a1c1.hi vpmulhw %ymm8,%ymm3,%ymm3 # a1d1.hi vpmulhw %ymm7,%ymm4,%ymm0 # b1c1.hi @@ -30,17 +46,22 @@ vpmulhw %ymm8,%ymm4,%ymm4 # b1d1.hi vmovdqa %ymm13,(%rsp) +/* Compute low-parts of monomials in (a0+b0*X)*(c0+d0*X), */ +/* using Montgomery twists calculated before */ vpmullw %ymm5,%ymm9,%ymm13 # a0c0.lo vpmullw %ymm6,%ymm9,%ymm9 # a0d0.lo vpmullw %ymm5,%ymm10,%ymm5 # b0c0.lo vpmullw %ymm6,%ymm10,%ymm10 # b0d0.lo +/* Compute low-parts of monomials in (a1+b1*X)*(c1+d1*X), */ +/* using Montgomery twists calculated before */ vpmullw %ymm7,%ymm11,%ymm6 # a1c1.lo vpmullw %ymm8,%ymm11,%ymm11 # a1d1.lo vpmullw %ymm7,%ymm12,%ymm7 # b1c1.lo vpmullw %ymm8,%ymm12,%ymm12 # b1d1.lo -vmovdqa _16XQ*2(%rcx),%ymm8 +/* Compute 2nd high multiplication in Montgomery multiplication */ +vmovdqa 
AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rcx),%ymm8 vpmulhw %ymm8,%ymm13,%ymm13 vpmulhw %ymm8,%ymm9,%ymm9 vpmulhw %ymm8,%ymm5,%ymm5 @@ -50,6 +71,7 @@ vpmulhw %ymm8,%ymm11,%ymm11 vpmulhw %ymm8,%ymm7,%ymm7 vpmulhw %ymm8,%ymm12,%ymm12 +/* Finish Montgomery multiplications */ vpsubw (%rsp),%ymm13,%ymm13 # -a0c0 vpsubw %ymm9,%ymm1,%ymm9 # a0d0 vpsubw %ymm5,%ymm14,%ymm5 # b0c0 @@ -60,6 +82,10 @@ vpsubw %ymm11,%ymm3,%ymm11 # a1d1 vpsubw %ymm7,%ymm0,%ymm7 # b1c1 vpsubw %ymm12,%ymm4,%ymm12 # b1d1 +/* b0*d0 and b1*d1 need twisting by a twiddle, accounting + * for X^2=zeta in F_q[X]/(X^2-zeta). + * + * TODO: This could be precomputed in the mulcache */ vmovdqa (%r9),%ymm0 vmovdqa 32(%r9),%ymm1 vpmullw %ymm0,%ymm10,%ymm2 @@ -76,6 +102,9 @@ vpaddw %ymm7,%ymm11,%ymm11 vpsubw %ymm13,%ymm10,%ymm13 vpsubw %ymm12,%ymm6,%ymm6 +/* Bounds: Since we are multiplying with signed canonical twiddles, + * each Montgomery multiplication has absolute value < q, + * and hence the coefficients of the output have absolute value < 2q. 
*/ vmovdqa %ymm13,(64*\off+ 0)*2(%rdi) vmovdqa %ymm9,(64*\off+16)*2(%rdi) vmovdqa %ymm6,(64*\off+32)*2(%rdi) @@ -83,13 +112,13 @@ vmovdqa %ymm11,(64*\off+48)*2(%rdi) .endm .text -.global cdecl(basemul_avx) -cdecl(basemul_avx): +.global MLKEM_ASM_NAMESPACE(basemul_avx2) +MLKEM_ASM_NAMESPACE(basemul_avx2): mov %rsp,%r8 and $-32,%rsp sub $32,%rsp -lea (_ZETAS_EXP+176)*2(%rcx),%r9 +lea (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+176)*2(%rcx),%r9 schoolbook 0 add $32*2,%r9 @@ -103,3 +132,5 @@ schoolbook 3 mov %r8,%rsp ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.c new file mode 100644 index 0000000000..5f9ae99c80 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/basemul.c @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "poly.h" +#include "polyvec.h" + +#include "arith_native_x86_64.h" +#include "consts.h" + +static void poly_basemul_montgomery_avx2(poly *r, const poly *a, const poly *b) +{ + basemul_avx2((__m256i *)r->coeffs, (const __m256i *)a->coeffs, + (const __m256i *)b->coeffs, qdata.vec); +} + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ +static void poly_add_avx2(poly *r, const poly *a, const poly *b) +{ + unsigned i; + __m256i f0, f1; + + for (i = 0; i < MLKEM_N; i += 16) + { + f0 = _mm256_load_si256((const __m256i *)&a->coeffs[i]); + f1 = _mm256_load_si256((const __m256i *)&b->coeffs[i]); + f0 = _mm256_add_epi16(f0, f1); + _mm256_store_si256((__m256i *)&r->coeffs[i], f0); + } +} + +void polyvec_basemul_acc_montgomery_cached_avx2(poly *r, const polyvec *a, + const polyvec *b, + const polyvec_mulcache *b_cache) +{ + unsigned i; + poly t; + + /* TODO: 
Use mulcache for AVX2. So far, it is unused. */ + ((void)b_cache); + + /* Coefficient-wise bound of each basemul is 2q. + * Since we are accumulating at most 4 times, the + * overall bound is 8q < INT16_MAX. */ + poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0]); + for (i = 1; i < MLKEM_K; i++) + { + poly_basemul_montgomery_avx2(&t, &a->vec[i], &b->vec[i]); + poly_add_avx2(r, r, &t); + } +} + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy constant to keep compiler happy despite empty CU */ + +#define empty_cu_avx2_basemul MLKEM_NAMESPACE(empty_cu_avx2_basemul) +int empty_cu_avx2_basemul; + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.c new file mode 100644 index 0000000000..86a0835efd --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.c @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/consts.c + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "align.h" +#include "consts.h" + +#define Q MLKEM_Q +#define MONT -1044 /* 2^16 mod q */ +#define QINV -3327 /* q^-1 mod 2^16 */ +#define V 20159 /* floor(2^26/q + 0.5) */ +#define FHI 1441 /* mont^2/128 */ +#define FLO -10079 /* qinv*FHI */ +#define MONTSQHI 1353 /* mont^2 */ +#define MONTSQLO 20553 /* qinv*MONTSQHI */ +#define MASK 4095 +#define SHIFT 32 + +const qdata_t qdata = {{ +#define AVX2_BACKEND_DATA_OFFSET_16XQ 0 + Q, Q, Q, Q, Q, Q, + Q, Q, Q, Q, Q, Q, + Q, Q, Q, Q, + +#define AVX2_BACKEND_DATA_OFFSET_16XQINV 16 + QINV, QINV, QINV, QINV, QINV, QINV, + QINV, QINV, QINV, QINV, QINV, QINV, + QINV, QINV, QINV, QINV, + +#define AVX2_BACKEND_DATA_OFFSET_16XV 32 + V, V, V, V, V, V, 
+ V, V, V, V, V, V, + V, V, V, V, + +#define AVX2_BACKEND_DATA_OFFSET_16XFLO 48 + FLO, FLO, FLO, FLO, FLO, FLO, + FLO, FLO, FLO, FLO, FLO, FLO, + FLO, FLO, FLO, FLO, + +#define AVX2_BACKEND_DATA_OFFSET_16XFHI 64 + FHI, FHI, FHI, FHI, FHI, FHI, + FHI, FHI, FHI, FHI, FHI, FHI, + FHI, FHI, FHI, FHI, + +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO 80 + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, + +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI 96 + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, + +#define AVX2_BACKEND_DATA_OFFSET_16XMASK 112 + MASK, MASK, MASK, MASK, MASK, MASK, + MASK, MASK, MASK, MASK, MASK, MASK, + MASK, MASK, MASK, MASK, + +#define AVX2_BACKEND_DATA_OFFSET_REVIDXB 128 + 3854, 3340, 2826, 2312, 1798, 1284, + 770, 256, 3854, 3340, 2826, 2312, + 1798, 1284, 770, 256, + +#define AVX2_BACKEND_DATA_OFFSET_REVIDXD 144 + 7, 0, 6, 0, 5, 0, + 4, 0, 3, 0, 2, 0, + 1, 0, 0, 0, + +#define AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP 160 +#include "x86_64_zetas.i" + +#define AVX2_BACKEND_DATA_OFFSET_16XSHIFT 624 + SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, + SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, + SHIFT, SHIFT, SHIFT, SHIFT}}; + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_consts MLKEM_NAMESPACE(empty_cu_consts) +int empty_cu_consts; +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.h new file mode 100644 index 0000000000..00c415952e --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/consts.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2024 The mlkem-native project 
authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2/consts.h + */ + +#ifndef CONSTS_H +#define CONSTS_H + +#include "common.h" + +#define AVX2_BACKEND_DATA_OFFSET_16XQ 0 +#define AVX2_BACKEND_DATA_OFFSET_16XQINV 16 +#define AVX2_BACKEND_DATA_OFFSET_16XV 32 +#define AVX2_BACKEND_DATA_OFFSET_16XFLO 48 +#define AVX2_BACKEND_DATA_OFFSET_16XFHI 64 +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO 80 +#define AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI 96 +#define AVX2_BACKEND_DATA_OFFSET_16XMASK 112 +#define AVX2_BACKEND_DATA_OFFSET_REVIDXB 128 +#define AVX2_BACKEND_DATA_OFFSET_REVIDXD 144 +#define AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP 160 +#define AVX2_BACKEND_DATA_OFFSET_16XSHIFT 624 + +/* The C ABI on MacOS exports all symbols with a leading + * underscore. This means that any symbols we refer to from + * C files (functions) can't be found, and all symbols we + * refer to from ASM also can't be found. + * + * This define helps us get around this + */ + +#ifndef __ASSEMBLER__ +#include "align.h" +typedef ALIGNED_INT16(640) qdata_t; +#define qdata MLKEM_NAMESPACE(qdata) +extern const qdata_t qdata; +#endif + +#endif diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/default_impl.h b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/default_impl.h new file mode 100644 index 0000000000..66de8c85f3 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/default_impl.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* ML-KEM arithmetic native profile for clean assembly */ + +#ifdef MLKEM_NATIVE_ARITH_PROFILE_IMPL_H +#error Only one MLKEM_ARITH assembly profile can be defined -- did you include multiple profiles? 
+#else +#define MLKEM_NATIVE_ARITH_PROFILE_IMPL_H + +#include + +#include "poly.h" +#include "polyvec.h" +#include "arith_native_x86_64.h" + +#define MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER + +#define MLKEM_USE_NATIVE_REJ_UNIFORM +#define MLKEM_USE_NATIVE_NTT +#define MLKEM_USE_NATIVE_INTT +#define MLKEM_USE_NATIVE_POLY_REDUCE +#define MLKEM_USE_NATIVE_POLY_TOMONT +#define MLKEM_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED +#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE +#define MLKEM_USE_NATIVE_POLY_TOBYTES +#define MLKEM_USE_NATIVE_POLY_FROMBYTES + +#define INVNTT_BOUND_NATIVE (8 * MLKEM_Q) +#define NTT_BOUND_NATIVE (8 * MLKEM_Q) + +static INLINE void poly_permute_bitrev_to_custom(poly *data) +{ + nttunpack_avx2((__m256i *)(data->coeffs), qdata.vec); +} + +static INLINE int rej_uniform_native(int16_t *r, unsigned int len, + const uint8_t *buf, unsigned int buflen) +{ + /* AVX2 implementation assumes specific buffer lengths */ + if (len != MLKEM_N || buflen != REJ_UNIFORM_AVX_BUFLEN) + { + return -1; + } + + return (int)rej_uniform_avx2(r, buf); +} + +static INLINE void ntt_native(poly *data) +{ + ntt_avx2((__m256i *)data, qdata.vec); +} + +static INLINE void intt_native(poly *data) +{ + invntt_avx2((__m256i *)data, qdata.vec); +} + +static INLINE void poly_reduce_native(poly *data) +{ + reduce_avx2((__m256i *)data->coeffs, qdata.vec); +} + +static INLINE void poly_tomont_native(poly *data) +{ + tomont_avx2((__m256i *)data->coeffs, qdata.vec); +} + +static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y) +{ + /* AVX2 backend does not use mulcache */ + ((void)y); + ((void)x); +} + +static INLINE void polyvec_basemul_acc_montgomery_cached_native( + poly *r, const polyvec *a, const polyvec *b, + const polyvec_mulcache *b_cache) +{ + polyvec_basemul_acc_montgomery_cached_avx2(r, a, b, b_cache); +} + +static INLINE void poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES], + const poly *a) +{ + ntttobytes_avx2(r, (const __m256i *)a->coeffs, qdata.vec); 
+} + +static INLINE void poly_frombytes_native(poly *r, + const uint8_t a[MLKEM_POLYBYTES]) +{ + nttfrombytes_avx2((__m256i *)r->coeffs, a, qdata.vec); +} + +#endif /* MLKEM_NATIVE_ARITH_PROFILE_IMPL_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/fq.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.S similarity index 50% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/fq.S rename to src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.S index 3bb1ebd3d8..134bd4f710 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/fq.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.S @@ -1,8 +1,25 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation based on Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +// Changes: +// - Add call to csubq in reduce128_avx2 to produce outputs +// in [0,1,...,q-1] rather than [0,1,...,q], matching the +// semantics of poly_reduce(). 
+ +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) #include "consts.h" -.include "fq.inc" + +#include "fq.inc" .text -reduce128_avx: +reduce128_avx2: #load vmovdqa (%rdi),%ymm2 vmovdqa 32(%rdi),%ymm3 @@ -22,6 +39,15 @@ red16 7 red16 8 red16 9 +csubq 2 +csubq 3 +csubq 4 +csubq 5 +csubq 6 +csubq 7 +csubq 8 +csubq 9 + #store vmovdqa %ymm2,(%rdi) vmovdqa %ymm3,32(%rdi) @@ -34,17 +60,18 @@ vmovdqa %ymm9,224(%rdi) ret -.global cdecl(reduce_avx) -cdecl(reduce_avx): +.global MLKEM_ASM_NAMESPACE(reduce_avx2) +MLKEM_ASM_NAMESPACE(reduce_avx2): #consts -vmovdqa _16XQ*2(%rsi),%ymm0 -vmovdqa _16XV*2(%rsi),%ymm1 -call reduce128_avx +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XV*2(%rsi),%ymm1 +call reduce128_avx2 add $256,%rdi -call reduce128_avx +call reduce128_avx2 ret -tomont128_avx: + +tomont128_avx2: #load vmovdqa (%rdi),%ymm3 vmovdqa 32(%rdi),%ymm4 @@ -76,13 +103,15 @@ vmovdqa %ymm10,224(%rdi) ret -.global cdecl(tomont_avx) -cdecl(tomont_avx): +.global MLKEM_ASM_NAMESPACE(tomont_avx2) +MLKEM_ASM_NAMESPACE(tomont_avx2): #consts -vmovdqa _16XQ*2(%rsi),%ymm0 -vmovdqa _16XMONTSQLO*2(%rsi),%ymm1 -vmovdqa _16XMONTSQHI*2(%rsi),%ymm2 -call tomont128_avx +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMONTSQLO*2(%rsi),%ymm1 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMONTSQHI*2(%rsi),%ymm2 +call tomont128_avx2 add $256,%rdi -call tomont128_avx +call tomont128_avx2 ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.inc b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.inc similarity index 67% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.inc rename to src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.inc index 4b7afc3118..76ec7a3b9e 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/fq.inc +++ 
b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/fq.inc @@ -1,3 +1,13 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ + .macro red16 r,rs=0,x=12 vpmulhw %ymm1,%ymm\r,%ymm\x .if \rs @@ -22,6 +32,8 @@ vpand %ymm0,%ymm\x,%ymm\x vpaddw %ymm\x,%ymm\r,%ymm\r .endm +/* Montgomery multiplication between b and ah, + * with Montgomery twist of ah in al. */ .macro fqmulprecomp al,ah,b,x=12 vpmullw %ymm\al,%ymm\b,%ymm\x vpmulhw %ymm\ah,%ymm\b,%ymm\b diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/intt.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/intt.S new file mode 100644 index 0000000000..6b1d78ef26 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/intt.S @@ -0,0 +1,255 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* Implementation based on Kyber repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + * + * Changes to placement of modular reductions have + * been made to simplify reasoning of non-overflow */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "consts.h" +#include "shuffle.inc" +#include "fq.inc" + +/* Compute four GS butterflies between rh{0,1,2,3} and rl{0,1,2,3}. 
+ * Butterflies 0,1 use root zh0 and twisted root zl0, and butterflies + * 2,3 use root zh1 and twisted root zl1 + * Results are again in rl{0-3} and rh{0-3} */ +.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3 +vpsubw %ymm\rl0,%ymm\rh0,%ymm12 /* ymm12 = rh0 - rl0 */ +vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 /* rl0 = rh0 + rl0 */ +vpsubw %ymm\rl1,%ymm\rh1,%ymm13 /* ymm13 = rh1 - rl1 */ + +vpmullw %ymm\zl0,%ymm12,%ymm\rh0 /* rh0 = (rh0 - rl0) * root0_twisted */ +vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 /* rl1 = rh1 + rl1 */ +vpsubw %ymm\rl2,%ymm\rh2,%ymm14 /* ymm14 = rh2 - rl2 */ + +vpmullw %ymm\zl0,%ymm13,%ymm\rh1 /* rh1 = (rh1 - rl1) * root0_twisted */ +vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 /* rl2 = rh2 + rl2 */ +vpsubw %ymm\rl3,%ymm\rh3,%ymm15 /* ymm15 = rh3 - rl3 */ + +vpmullw %ymm\zl1,%ymm14,%ymm\rh2 /* rh2 = (rh2 - rl2) * root1_twisted */ +vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 /* rl3 = rh3 + rl3 */ +vpmullw %ymm\zl1,%ymm15,%ymm\rh3 /* rh3 = (rh3 - rl3) * root1_twisted */ + +vpmulhw %ymm\zh0,%ymm12,%ymm12 /* ymm12 = (rh0 - rl0) * root0 */ +vpmulhw %ymm\zh0,%ymm13,%ymm13 /* ymm13 = (rh1 - rl1) * root0 */ + +vpmulhw %ymm\zh1,%ymm14,%ymm14 /* ymm14 = (rh2 - rl2) * root1 */ +vpmulhw %ymm\zh1,%ymm15,%ymm15 /* ymm15 = (rh3 - rl3) * root1 */ + +vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 /* rh0 = Q * [(rh0 - rl0) * root0_twisted] */ +vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 /* rh1 = Q * [(rh1 - rl1) * root0_twisted] */ +vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 /* rh2 = Q * [(rh2 - rl2) * root1_twisted] */ +vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 /* rh3 = Q * [(rh3 - rl3) * root1_twisted] */ + +vpsubw %ymm\rh0,%ymm12,%ymm\rh0 /* rh0 = montmul(rh0-rl0, root0) */ +vpsubw %ymm\rh1,%ymm13,%ymm\rh1 /* rh1 = montmul(rh1-rl1, root0) */ +vpsubw %ymm\rh2,%ymm14,%ymm\rh2 /* rh2 = montmul(rh2-rl2, root1) */ +vpsubw %ymm\rh3,%ymm15,%ymm\rh3 /* rh3 = montmul(rh3-rl3, root1) */ +.endm + +.macro intt_levels0t5 off +/* level 0 */ +/* no bounds assumptions */ +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XFLO*2(%rsi),%ymm2 
+vmovdqa AVX2_BACKEND_DATA_OFFSET_16XFHI*2(%rsi),%ymm3 + +vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 + +fqmulprecomp 2,3,4 +fqmulprecomp 2,3,6 +fqmulprecomp 2,3,5 +fqmulprecomp 2,3,7 + +vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 +vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 +vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 +vmovdqa (128*\off+112)*2(%rdi),%ymm11 + +fqmulprecomp 2,3,8 +fqmulprecomp 2,3,10 +fqmulprecomp 2,3,9 +fqmulprecomp 2,3,11 + +/* bounds: coefficients < q */ + +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+192)*2(%rsi),%ymm3 +vmovdqa AVX2_BACKEND_DATA_OFFSET_REVIDXB*2(%rsi),%ymm12 +vpshufb %ymm12,%ymm15,%ymm15 +vpshufb %ymm12,%ymm1,%ymm1 +vpshufb %ymm12,%ymm2,%ymm2 +vpshufb %ymm12,%ymm3,%ymm3 + +butterfly 4,5,8,9,6,7,10,11,15,1,2,3 + +/* Montgomery multiplication with a signed canonical twiddle + * always has absolute value < q. This is used henceforth to + * normalize the absolute bounds on the second half inputs + * to the current butterfly + * + * 4,5,8,9 abs bound < 2q; 6,7,10,11 abs bound < q */ + +/* level 1 */ +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2 +vpermq $0x4E,(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3 +vmovdqa AVX2_BACKEND_DATA_OFFSET_REVIDXB*2(%rsi),%ymm1 +vpshufb %ymm1,%ymm2,%ymm2 +vpshufb %ymm1,%ymm3,%ymm3 + +butterfly 4,5,6,7,8,9,10,11,2,2,3,3 + +/* For 8,9,10,11, it is sufficient to use the bound INT16_MAX). 
*/ +red16 7 +/* global abs bound < 4q */ + +vmovdqa %ymm7,(128*\off+ 0)*2(%rdi) +vmovdqa %ymm9,(128*\off+ 16)*2(%rdi) +vmovdqa %ymm6,(128*\off+ 32)*2(%rdi) +vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) +vmovdqa %ymm10,(128*\off+ 64)*2(%rdi) +vmovdqa %ymm4,(128*\off+ 80)*2(%rdi) +vmovdqa %ymm5,(128*\off+ 96)*2(%rdi) +vmovdqa %ymm11,(128*\off+112)*2(%rdi) +.endm + +.macro intt_level6 off +/* level 6 */ +vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (64*\off+128)*2(%rdi),%ymm8 +vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (64*\off+144)*2(%rdi),%ymm9 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+0)*2(%rsi),%ymm2 + +vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (64*\off+160)*2(%rdi),%ymm10 +vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 +vmovdqa (64*\off+176)*2(%rdi),%ymm11 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+4)*2(%rsi),%ymm3 + +butterfly 4,5,6,7,8,9,10,11 +/* global abs bound < 8q */ + +/* REF-CHANGE: The official AVX2 implementation has a `red16 4` for `off=0`. + * We don't need this because of the earlier red16 which ensures an 8q bound */ + +vmovdqa %ymm4,(64*\off+ 0)*2(%rdi) +vmovdqa %ymm5,(64*\off+ 16)*2(%rdi) +vmovdqa %ymm6,(64*\off+ 32)*2(%rdi) +vmovdqa %ymm7,(64*\off+ 48)*2(%rdi) +vmovdqa %ymm8,(64*\off+128)*2(%rdi) +vmovdqa %ymm9,(64*\off+144)*2(%rdi) +vmovdqa %ymm10,(64*\off+160)*2(%rdi) +vmovdqa %ymm11,(64*\off+176)*2(%rdi) +.endm + +.text +.global MLKEM_ASM_NAMESPACE(invntt_avx2) +MLKEM_ASM_NAMESPACE(invntt_avx2): +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 + +intt_levels0t5 0 +intt_levels0t5 1 + +intt_level6 0 +intt_level6 1 +ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/ntt.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/ntt.S new file mode 100644 index 0000000000..e8bf7894b4 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/ntt.S @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include "consts.h" +#include "shuffle.inc" + +/* Compute steps 1,2 / 3 of Montgomery multiplication */ +.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2 +vpmullw %ymm\zl0,%ymm\rh0,%ymm12 +vpmullw %ymm\zl0,%ymm\rh1,%ymm13 + +vpmullw %ymm\zl1,%ymm\rh2,%ymm14 +vpmullw %ymm\zl1,%ymm\rh3,%ymm15 + +vpmulhw %ymm\zh0,%ymm\rh0,%ymm\rh0 +vpmulhw %ymm\zh0,%ymm\rh1,%ymm\rh1 + +vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2 +vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3 +.endm + +/* Compute step 3 / 3 of Montgomery multiplication */ +/* Multiply-high is signed; outputs are bound by 2^15 * q in abs value */ +.macro reduce +vpmulhw %ymm0,%ymm12,%ymm12 +vpmulhw %ymm0,%ymm13,%ymm13 + +vpmulhw %ymm0,%ymm14,%ymm14 +vpmulhw %ymm0,%ymm15,%ymm15 +.endm + +/* Finish Montgomery multiplication and compute add/sub steps in NTT butterfly + * + * At this point, the two high-products of 4 ongoing Montgomery multiplications + * are in %ymm{12,13,14,15} and %ymm{rh{0,1,2,3}}, respectively. + * The NTT coefficients that the results of the Montgomery multiplications should + * be add/sub-ed with, are in %ymm{rl{0,1,2,3}}. + * + * What's interesting, here, is that rather than completing the Montgomery + * multiplications by computing `%ymm{12+i} + %ymm{rh{i}}`, and then add/sub'ing + * the result into %ymm{rl{0,1,2,3}}, we add/sub both `%ymm{12+i}` and + * %ymm{rh{i}} to %ymm{rl{0,1,2,3}}, and then add the results. + * + * Functionally, though, this is still a signed Montgomery multiplication + * followed by an add/sub. + * + * Since the result of the Montgomery multiplication is bounded + * by q in absolute value, the coefficients overall grow by not + * more than q in absolute value per layer. 
*/ +.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3 +vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln /* rln = rl0 + rh0 */ +vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 /* rh0 = rl0 - rh0 */ +vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 /* rl0 = rl1 + rh1 */ +vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 /* rh1 = rl1 - rh1 */ +vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 /* rl1 = rl2 + rh2 */ +vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 /* rh2 = rl2 - rh2 */ +vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 /* rl2 = rl3 + rh3 */ +vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 /* rh3 = rl3 - rh3 */ + +vpsubw %ymm12,%ymm\rln,%ymm\rln /* rln = rh0 + rl0 - ymm12 = rl0 + (rh0 - ymm12) */ +vpaddw %ymm12,%ymm\rh0,%ymm\rh0 /* rh0 = rl0 - rh0 + ymm12 = rl0 - (rh0 - ymm12) */ +vpsubw %ymm13,%ymm\rl0,%ymm\rl0 /* rl0 = rl1 + rh1 - ymm13 = rl1 + (rh1 - ymm13) */ +vpaddw %ymm13,%ymm\rh1,%ymm\rh1 /* rh1 = rl1 - rh1 + ymm13 = rl1 - (rh1 - ymm13) */ +vpsubw %ymm14,%ymm\rl1,%ymm\rl1 /* rl1 = rh2 + rl2 - ymm14 = rl2 + (rh2 - ymm14) */ +vpaddw %ymm14,%ymm\rh2,%ymm\rh2 /* rh2 = rl2 - rh2 + ymm14 = rl2 - (rh2 - ymm14) */ +vpsubw %ymm15,%ymm\rl2,%ymm\rl2 /* rl2 = rh3 + rl3 - ymm15 = rl3 + (rh3 - ymm15) */ +vpaddw %ymm15,%ymm\rh3,%ymm\rh3 /* rh3 = rl3 - rh3 + ymm15 = rl3 - (rh3 - ymm15) */ +.endm + +.macro level0 off +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+0)*2(%rsi),%ymm15 +vmovdqa (64*\off+128)*2(%rdi),%ymm8 +vmovdqa (64*\off+144)*2(%rdi),%ymm9 +vmovdqa (64*\off+160)*2(%rdi),%ymm10 +vmovdqa (64*\off+176)*2(%rdi),%ymm11 +vpbroadcastq (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+4)*2(%rsi),%ymm2 + +mul 8,9,10,11 + +vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 + +reduce +update 3,4,5,6,7,8,9,10,11 + +vmovdqa %ymm3,(64*\off+ 0)*2(%rdi) +vmovdqa %ymm4,(64*\off+ 16)*2(%rdi) +vmovdqa %ymm5,(64*\off+ 32)*2(%rdi) +vmovdqa %ymm6,(64*\off+ 48)*2(%rdi) +vmovdqa %ymm8,(64*\off+128)*2(%rdi) +vmovdqa %ymm9,(64*\off+144)*2(%rdi) +vmovdqa %ymm10,(64*\off+160)*2(%rdi) +vmovdqa 
%ymm11,(64*\off+176)*2(%rdi) +.endm + +.macro levels1t6 off +/* level 1 */ +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+16)*2(%rsi),%ymm15 +vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 +vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 +vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 +vmovdqa (128*\off+112)*2(%rdi),%ymm11 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+32)*2(%rsi),%ymm2 + +mul 8,9,10,11 + +vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 +vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 +vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 +vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 + +reduce +update 3,4,5,6,7,8,9,10,11 + +/* level 2 */ +shuffle8 5,10,7,10 +shuffle8 6,11,5,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+48)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+64)*2(%rsi),%ymm2 + +mul 7,10,5,11 + +shuffle8 3,8,6,8 +shuffle8 4,9,3,9 + +reduce +update 4,6,8,3,9,7,10,5,11 + +/* level 3 */ +shuffle4 8,5,9,5 +shuffle4 3,11,8,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+80)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+96)*2(%rsi),%ymm2 + +mul 9,5,8,11 + +shuffle4 4,7,3,7 +shuffle4 6,10,4,10 + +reduce +update 6,3,7,4,10,9,5,8,11 + +/* level 4 */ +shuffle2 7,8,10,8 +shuffle2 4,11,7,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+112)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+128)*2(%rsi),%ymm2 + +mul 10,8,7,11 + +shuffle2 6,9,4,9 +shuffle2 3,5,6,5 + +reduce +update 3,4,9,6,5,10,8,7,11 + +/* level 5 */ +shuffle1 9,7,5,7 +shuffle1 6,11,9,11 + +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+144)*2(%rsi),%ymm15 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+160)*2(%rsi),%ymm2 + +mul 5,7,9,11 + +shuffle1 3,10,6,10 +shuffle1 4,8,3,8 + +reduce +update 4,6,10,3,8,5,7,9,11 + +/* level 6 */ +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+176)*2(%rsi),%ymm14 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+208)*2(%rsi),%ymm15 +vmovdqa 
(AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+192)*2(%rsi),%ymm8 +vmovdqa (AVX2_BACKEND_DATA_OFFSET_ZETAS_EXP+224*\off+224)*2(%rsi),%ymm2 + +mul 10,3,9,11,14,15,8,2 + +reduce +update 8,4,6,5,7,10,3,9,11 + +vmovdqa %ymm8,(128*\off+ 0)*2(%rdi) +vmovdqa %ymm4,(128*\off+ 16)*2(%rdi) +vmovdqa %ymm10,(128*\off+ 32)*2(%rdi) +vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) +vmovdqa %ymm6,(128*\off+ 64)*2(%rdi) +vmovdqa %ymm5,(128*\off+ 80)*2(%rdi) +vmovdqa %ymm9,(128*\off+ 96)*2(%rdi) +vmovdqa %ymm11,(128*\off+112)*2(%rdi) +.endm + +.text +.global MLKEM_ASM_NAMESPACE(ntt_avx2) +MLKEM_ASM_NAMESPACE(ntt_avx2): +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rsi),%ymm0 + +level0 0 +level0 1 + +levels1t6 0 +levels1t6 1 + +ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_avx2.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_avx2.c new file mode 100644 index 0000000000..54037a0df9 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_avx2.c @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * Implementation from Kyber reference repository + * https://github.com/pq-crystals/kyber/blob/main/avx2 + */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include +#include +#include +#include "arith_native_x86_64.h" +#include "consts.h" + +unsigned int rej_uniform_avx2(int16_t *RESTRICT r, const uint8_t *buf) +{ + unsigned int ctr, pos; + uint16_t val0, val1; + uint32_t good; + const __m256i bound = + _mm256_load_si256(&qdata.vec[AVX2_BACKEND_DATA_OFFSET_16XQ / 16]); + const __m256i ones = _mm256_set1_epi8(1); + const __m256i mask = _mm256_set1_epi16(0xFFF); + const __m256i idx8 = + _mm256_set_epi8(15, 14, 14, 13, 12, 11, 11, 10, 9, 8, 8, 7, 6, 5, 5, 4, + 11, 10, 10, 9, 8, 7, 7, 6, 5, 4, 4, 3, 2, 1, 1, 0); + __m256i f0, f1, g0, g1, g2, g3; 
+ __m128i f, t, pilo, pihi; + + ctr = pos = 0; + while (ctr <= MLKEM_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 48) + { + f0 = _mm256_loadu_si256((__m256i *)&buf[pos]); + /* Don't load from offset 24, as this would over-read the buffer */ + f1 = _mm256_loadu_si256((__m256i *)&buf[pos + 16]); + f0 = _mm256_permute4x64_epi64(f0, 0x94 /* 0b10010100 ~= (2,1,1,0) */); + f1 = _mm256_permute4x64_epi64(f1, 0xe9 /* 0x11101001 ~= (3,2,2,1) */); + f0 = _mm256_shuffle_epi8(f0, idx8); + f1 = _mm256_shuffle_epi8(f1, idx8); + g0 = _mm256_srli_epi16(f0, 4); + g1 = _mm256_srli_epi16(f1, 4); + f0 = _mm256_blend_epi16(f0, g0, 0xAA); + f1 = _mm256_blend_epi16(f1, g1, 0xAA); + f0 = _mm256_and_si256(f0, mask); + f1 = _mm256_and_si256(f1, mask); + pos += 48; + + g0 = _mm256_cmpgt_epi16(bound, f0); + g1 = _mm256_cmpgt_epi16(bound, f1); + + g0 = _mm256_packs_epi16(g0, g1); + good = _mm256_movemask_epi8(g0); + + g0 = _mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 0) & 0xFF])); + g1 = _mm256_castsi128_si256( + _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 8) & 0xFF])); + g0 = _mm256_inserti128_si256( + g0, _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 16) & 0xFF]), + 1); + g1 = _mm256_inserti128_si256( + g1, _mm_loadl_epi64((__m128i *)&rej_uniform_table[(good >> 24) & 0xFF]), + 1); + + g2 = _mm256_add_epi8(g0, ones); + g3 = _mm256_add_epi8(g1, ones); + g0 = _mm256_unpacklo_epi8(g0, g2); + g1 = _mm256_unpacklo_epi8(g1, g3); + + f0 = _mm256_shuffle_epi8(f0, g0); + f1 = _mm256_shuffle_epi8(f1, g1); + + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f0)); + ctr += _mm_popcnt_u32((good >> 0) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f0, 1)); + ctr += _mm_popcnt_u32((good >> 16) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f1)); + ctr += _mm_popcnt_u32((good >> 8) & 0xFF); + _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f1, 1)); + ctr += _mm_popcnt_u32((good >> 
24) & 0xFF); + } + + while (ctr <= MLKEM_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 24) + { + f = _mm_loadu_si128((__m128i *)&buf[pos]); + f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8)); + t = _mm_srli_epi16(f, 4); + f = _mm_blend_epi16(f, t, 0xAA); + f = _mm_and_si128(f, _mm256_castsi256_si128(mask)); + pos += 12; + + t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f); + good = _mm_movemask_epi8(t); + + good = _pext_u32(good, 0x5555); + pilo = _mm_loadl_epi64((__m128i *)&rej_uniform_table[good]); + + pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones)); + pilo = _mm_unpacklo_epi8(pilo, pihi); + f = _mm_shuffle_epi8(f, pilo); + _mm_storeu_si128((__m128i *)&r[ctr], f); + ctr += _mm_popcnt_u32(good); + } + + while (ctr < MLKEM_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3) + { + val0 = ((buf[pos + 0] >> 0) | ((uint16_t)buf[pos + 1] << 8)) & 0xFFF; + val1 = ((buf[pos + 1] >> 4) | ((uint16_t)buf[pos + 2] << 4)); + pos += 3; + + if (val0 < MLKEM_Q) + r[ctr++] = val0; + if (val1 < MLKEM_Q && ctr < MLKEM_N) + r[ctr++] = val1; + } + + return ctr; +} + +#else /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_rej_uniform_avx2 MLKEM_NAMESPACE(empty_cu_rej_uniform_avx2) +int empty_cu_rej_uniform_avx2; +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_table.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_table.c new file mode 100644 index 0000000000..9bbc47146f --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/rej_uniform_table.c @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. 
+ */ + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + +#include +#include "arith_native_x86_64.h" + +/* + * Lookup table used by rejection sampling of the public matrix. + * See autogen for details. + */ +ALIGN const uint8_t rej_uniform_table[256][8] = { + {-1, -1, -1, -1, -1, -1, -1, -1}, {0, -1, -1, -1, -1, -1, -1, -1}, + {2, -1, -1, -1, -1, -1, -1, -1}, {0, 2, -1, -1, -1, -1, -1, -1}, + {4, -1, -1, -1, -1, -1, -1, -1}, {0, 4, -1, -1, -1, -1, -1, -1}, + {2, 4, -1, -1, -1, -1, -1, -1}, {0, 2, 4, -1, -1, -1, -1, -1}, + {6, -1, -1, -1, -1, -1, -1, -1}, {0, 6, -1, -1, -1, -1, -1, -1}, + {2, 6, -1, -1, -1, -1, -1, -1}, {0, 2, 6, -1, -1, -1, -1, -1}, + {4, 6, -1, -1, -1, -1, -1, -1}, {0, 4, 6, -1, -1, -1, -1, -1}, + {2, 4, 6, -1, -1, -1, -1, -1}, {0, 2, 4, 6, -1, -1, -1, -1}, + {8, -1, -1, -1, -1, -1, -1, -1}, {0, 8, -1, -1, -1, -1, -1, -1}, + {2, 8, -1, -1, -1, -1, -1, -1}, {0, 2, 8, -1, -1, -1, -1, -1}, + {4, 8, -1, -1, -1, -1, -1, -1}, {0, 4, 8, -1, -1, -1, -1, -1}, + {2, 4, 8, -1, -1, -1, -1, -1}, {0, 2, 4, 8, -1, -1, -1, -1}, + {6, 8, -1, -1, -1, -1, -1, -1}, {0, 6, 8, -1, -1, -1, -1, -1}, + {2, 6, 8, -1, -1, -1, -1, -1}, {0, 2, 6, 8, -1, -1, -1, -1}, + {4, 6, 8, -1, -1, -1, -1, -1}, {0, 4, 6, 8, -1, -1, -1, -1}, + {2, 4, 6, 8, -1, -1, -1, -1}, {0, 2, 4, 6, 8, -1, -1, -1}, + {10, -1, -1, -1, -1, -1, -1, -1}, {0, 10, -1, -1, -1, -1, -1, -1}, + {2, 10, -1, -1, -1, -1, -1, -1}, {0, 2, 10, -1, -1, -1, -1, -1}, + {4, 10, -1, -1, -1, -1, -1, -1}, {0, 4, 10, -1, -1, -1, -1, -1}, + {2, 4, 10, -1, -1, -1, -1, -1}, {0, 2, 4, 10, -1, -1, -1, -1}, + {6, 10, -1, -1, -1, -1, -1, -1}, {0, 6, 10, -1, -1, -1, -1, -1}, + {2, 6, 10, -1, -1, -1, -1, -1}, {0, 2, 6, 10, -1, -1, -1, -1}, + {4, 6, 10, -1, -1, -1, -1, -1}, {0, 4, 6, 10, -1, -1, -1, -1}, + {2, 4, 6, 10, -1, -1, -1, -1}, {0, 2, 4, 6, 10, -1, -1, -1}, + {8, 10, -1, -1, -1, -1, -1, -1}, {0, 8, 10, -1, -1, -1, -1, -1}, + {2, 8, 10, -1, -1, -1, -1, -1}, {0, 2, 8, 10, -1, -1, -1, -1}, + {4, 8, 
10, -1, -1, -1, -1, -1}, {0, 4, 8, 10, -1, -1, -1, -1}, + {2, 4, 8, 10, -1, -1, -1, -1}, {0, 2, 4, 8, 10, -1, -1, -1}, + {6, 8, 10, -1, -1, -1, -1, -1}, {0, 6, 8, 10, -1, -1, -1, -1}, + {2, 6, 8, 10, -1, -1, -1, -1}, {0, 2, 6, 8, 10, -1, -1, -1}, + {4, 6, 8, 10, -1, -1, -1, -1}, {0, 4, 6, 8, 10, -1, -1, -1}, + {2, 4, 6, 8, 10, -1, -1, -1}, {0, 2, 4, 6, 8, 10, -1, -1}, + {12, -1, -1, -1, -1, -1, -1, -1}, {0, 12, -1, -1, -1, -1, -1, -1}, + {2, 12, -1, -1, -1, -1, -1, -1}, {0, 2, 12, -1, -1, -1, -1, -1}, + {4, 12, -1, -1, -1, -1, -1, -1}, {0, 4, 12, -1, -1, -1, -1, -1}, + {2, 4, 12, -1, -1, -1, -1, -1}, {0, 2, 4, 12, -1, -1, -1, -1}, + {6, 12, -1, -1, -1, -1, -1, -1}, {0, 6, 12, -1, -1, -1, -1, -1}, + {2, 6, 12, -1, -1, -1, -1, -1}, {0, 2, 6, 12, -1, -1, -1, -1}, + {4, 6, 12, -1, -1, -1, -1, -1}, {0, 4, 6, 12, -1, -1, -1, -1}, + {2, 4, 6, 12, -1, -1, -1, -1}, {0, 2, 4, 6, 12, -1, -1, -1}, + {8, 12, -1, -1, -1, -1, -1, -1}, {0, 8, 12, -1, -1, -1, -1, -1}, + {2, 8, 12, -1, -1, -1, -1, -1}, {0, 2, 8, 12, -1, -1, -1, -1}, + {4, 8, 12, -1, -1, -1, -1, -1}, {0, 4, 8, 12, -1, -1, -1, -1}, + {2, 4, 8, 12, -1, -1, -1, -1}, {0, 2, 4, 8, 12, -1, -1, -1}, + {6, 8, 12, -1, -1, -1, -1, -1}, {0, 6, 8, 12, -1, -1, -1, -1}, + {2, 6, 8, 12, -1, -1, -1, -1}, {0, 2, 6, 8, 12, -1, -1, -1}, + {4, 6, 8, 12, -1, -1, -1, -1}, {0, 4, 6, 8, 12, -1, -1, -1}, + {2, 4, 6, 8, 12, -1, -1, -1}, {0, 2, 4, 6, 8, 12, -1, -1}, + {10, 12, -1, -1, -1, -1, -1, -1}, {0, 10, 12, -1, -1, -1, -1, -1}, + {2, 10, 12, -1, -1, -1, -1, -1}, {0, 2, 10, 12, -1, -1, -1, -1}, + {4, 10, 12, -1, -1, -1, -1, -1}, {0, 4, 10, 12, -1, -1, -1, -1}, + {2, 4, 10, 12, -1, -1, -1, -1}, {0, 2, 4, 10, 12, -1, -1, -1}, + {6, 10, 12, -1, -1, -1, -1, -1}, {0, 6, 10, 12, -1, -1, -1, -1}, + {2, 6, 10, 12, -1, -1, -1, -1}, {0, 2, 6, 10, 12, -1, -1, -1}, + {4, 6, 10, 12, -1, -1, -1, -1}, {0, 4, 6, 10, 12, -1, -1, -1}, + {2, 4, 6, 10, 12, -1, -1, -1}, {0, 2, 4, 6, 10, 12, -1, -1}, + {8, 10, 12, -1, -1, -1, -1, -1}, {0, 8, 10, 12, -1, -1, 
-1, -1}, + {2, 8, 10, 12, -1, -1, -1, -1}, {0, 2, 8, 10, 12, -1, -1, -1}, + {4, 8, 10, 12, -1, -1, -1, -1}, {0, 4, 8, 10, 12, -1, -1, -1}, + {2, 4, 8, 10, 12, -1, -1, -1}, {0, 2, 4, 8, 10, 12, -1, -1}, + {6, 8, 10, 12, -1, -1, -1, -1}, {0, 6, 8, 10, 12, -1, -1, -1}, + {2, 6, 8, 10, 12, -1, -1, -1}, {0, 2, 6, 8, 10, 12, -1, -1}, + {4, 6, 8, 10, 12, -1, -1, -1}, {0, 4, 6, 8, 10, 12, -1, -1}, + {2, 4, 6, 8, 10, 12, -1, -1}, {0, 2, 4, 6, 8, 10, 12, -1}, + {14, -1, -1, -1, -1, -1, -1, -1}, {0, 14, -1, -1, -1, -1, -1, -1}, + {2, 14, -1, -1, -1, -1, -1, -1}, {0, 2, 14, -1, -1, -1, -1, -1}, + {4, 14, -1, -1, -1, -1, -1, -1}, {0, 4, 14, -1, -1, -1, -1, -1}, + {2, 4, 14, -1, -1, -1, -1, -1}, {0, 2, 4, 14, -1, -1, -1, -1}, + {6, 14, -1, -1, -1, -1, -1, -1}, {0, 6, 14, -1, -1, -1, -1, -1}, + {2, 6, 14, -1, -1, -1, -1, -1}, {0, 2, 6, 14, -1, -1, -1, -1}, + {4, 6, 14, -1, -1, -1, -1, -1}, {0, 4, 6, 14, -1, -1, -1, -1}, + {2, 4, 6, 14, -1, -1, -1, -1}, {0, 2, 4, 6, 14, -1, -1, -1}, + {8, 14, -1, -1, -1, -1, -1, -1}, {0, 8, 14, -1, -1, -1, -1, -1}, + {2, 8, 14, -1, -1, -1, -1, -1}, {0, 2, 8, 14, -1, -1, -1, -1}, + {4, 8, 14, -1, -1, -1, -1, -1}, {0, 4, 8, 14, -1, -1, -1, -1}, + {2, 4, 8, 14, -1, -1, -1, -1}, {0, 2, 4, 8, 14, -1, -1, -1}, + {6, 8, 14, -1, -1, -1, -1, -1}, {0, 6, 8, 14, -1, -1, -1, -1}, + {2, 6, 8, 14, -1, -1, -1, -1}, {0, 2, 6, 8, 14, -1, -1, -1}, + {4, 6, 8, 14, -1, -1, -1, -1}, {0, 4, 6, 8, 14, -1, -1, -1}, + {2, 4, 6, 8, 14, -1, -1, -1}, {0, 2, 4, 6, 8, 14, -1, -1}, + {10, 14, -1, -1, -1, -1, -1, -1}, {0, 10, 14, -1, -1, -1, -1, -1}, + {2, 10, 14, -1, -1, -1, -1, -1}, {0, 2, 10, 14, -1, -1, -1, -1}, + {4, 10, 14, -1, -1, -1, -1, -1}, {0, 4, 10, 14, -1, -1, -1, -1}, + {2, 4, 10, 14, -1, -1, -1, -1}, {0, 2, 4, 10, 14, -1, -1, -1}, + {6, 10, 14, -1, -1, -1, -1, -1}, {0, 6, 10, 14, -1, -1, -1, -1}, + {2, 6, 10, 14, -1, -1, -1, -1}, {0, 2, 6, 10, 14, -1, -1, -1}, + {4, 6, 10, 14, -1, -1, -1, -1}, {0, 4, 6, 10, 14, -1, -1, -1}, + {2, 4, 6, 10, 14, -1, -1, -1}, {0, 2, 
4, 6, 10, 14, -1, -1}, + {8, 10, 14, -1, -1, -1, -1, -1}, {0, 8, 10, 14, -1, -1, -1, -1}, + {2, 8, 10, 14, -1, -1, -1, -1}, {0, 2, 8, 10, 14, -1, -1, -1}, + {4, 8, 10, 14, -1, -1, -1, -1}, {0, 4, 8, 10, 14, -1, -1, -1}, + {2, 4, 8, 10, 14, -1, -1, -1}, {0, 2, 4, 8, 10, 14, -1, -1}, + {6, 8, 10, 14, -1, -1, -1, -1}, {0, 6, 8, 10, 14, -1, -1, -1}, + {2, 6, 8, 10, 14, -1, -1, -1}, {0, 2, 6, 8, 10, 14, -1, -1}, + {4, 6, 8, 10, 14, -1, -1, -1}, {0, 4, 6, 8, 10, 14, -1, -1}, + {2, 4, 6, 8, 10, 14, -1, -1}, {0, 2, 4, 6, 8, 10, 14, -1}, + {12, 14, -1, -1, -1, -1, -1, -1}, {0, 12, 14, -1, -1, -1, -1, -1}, + {2, 12, 14, -1, -1, -1, -1, -1}, {0, 2, 12, 14, -1, -1, -1, -1}, + {4, 12, 14, -1, -1, -1, -1, -1}, {0, 4, 12, 14, -1, -1, -1, -1}, + {2, 4, 12, 14, -1, -1, -1, -1}, {0, 2, 4, 12, 14, -1, -1, -1}, + {6, 12, 14, -1, -1, -1, -1, -1}, {0, 6, 12, 14, -1, -1, -1, -1}, + {2, 6, 12, 14, -1, -1, -1, -1}, {0, 2, 6, 12, 14, -1, -1, -1}, + {4, 6, 12, 14, -1, -1, -1, -1}, {0, 4, 6, 12, 14, -1, -1, -1}, + {2, 4, 6, 12, 14, -1, -1, -1}, {0, 2, 4, 6, 12, 14, -1, -1}, + {8, 12, 14, -1, -1, -1, -1, -1}, {0, 8, 12, 14, -1, -1, -1, -1}, + {2, 8, 12, 14, -1, -1, -1, -1}, {0, 2, 8, 12, 14, -1, -1, -1}, + {4, 8, 12, 14, -1, -1, -1, -1}, {0, 4, 8, 12, 14, -1, -1, -1}, + {2, 4, 8, 12, 14, -1, -1, -1}, {0, 2, 4, 8, 12, 14, -1, -1}, + {6, 8, 12, 14, -1, -1, -1, -1}, {0, 6, 8, 12, 14, -1, -1, -1}, + {2, 6, 8, 12, 14, -1, -1, -1}, {0, 2, 6, 8, 12, 14, -1, -1}, + {4, 6, 8, 12, 14, -1, -1, -1}, {0, 4, 6, 8, 12, 14, -1, -1}, + {2, 4, 6, 8, 12, 14, -1, -1}, {0, 2, 4, 6, 8, 12, 14, -1}, + {10, 12, 14, -1, -1, -1, -1, -1}, {0, 10, 12, 14, -1, -1, -1, -1}, + {2, 10, 12, 14, -1, -1, -1, -1}, {0, 2, 10, 12, 14, -1, -1, -1}, + {4, 10, 12, 14, -1, -1, -1, -1}, {0, 4, 10, 12, 14, -1, -1, -1}, + {2, 4, 10, 12, 14, -1, -1, -1}, {0, 2, 4, 10, 12, 14, -1, -1}, + {6, 10, 12, 14, -1, -1, -1, -1}, {0, 6, 10, 12, 14, -1, -1, -1}, + {2, 6, 10, 12, 14, -1, -1, -1}, {0, 2, 6, 10, 12, 14, -1, -1}, + {4, 6, 10, 12, 14, -1, 
-1, -1}, {0, 4, 6, 10, 12, 14, -1, -1}, + {2, 4, 6, 10, 12, 14, -1, -1}, {0, 2, 4, 6, 10, 12, 14, -1}, + {8, 10, 12, 14, -1, -1, -1, -1}, {0, 8, 10, 12, 14, -1, -1, -1}, + {2, 8, 10, 12, 14, -1, -1, -1}, {0, 2, 8, 10, 12, 14, -1, -1}, + {4, 8, 10, 12, 14, -1, -1, -1}, {0, 4, 8, 10, 12, 14, -1, -1}, + {2, 4, 8, 10, 12, 14, -1, -1}, {0, 2, 4, 8, 10, 12, 14, -1}, + {6, 8, 10, 12, 14, -1, -1, -1}, {0, 6, 8, 10, 12, 14, -1, -1}, + {2, 6, 8, 10, 12, 14, -1, -1}, {0, 2, 6, 8, 10, 12, 14, -1}, + {4, 6, 8, 10, 12, 14, -1, -1}, {0, 4, 6, 8, 10, 12, 14, -1}, + {2, 4, 6, 8, 10, 12, 14, -1}, {0, 2, 4, 6, 8, 10, 12, 14}, +}; + +#else + +/* Dummy declaration for compilers disliking empty compilation units */ +#define empty_cu_avx2_rej_uniform_table \ + MLKEM_NAMESPACE(empty_cu_avx2_rej_uniform_table) +int empty_cu_avx2_rej_uniform_table; +#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.S b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.S similarity index 81% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.S rename to src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.S index 18325ebec0..5e708748a8 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/shuffle.S +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.S @@ -1,9 +1,21 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "common.h" + +#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT) + #include "consts.h" -.include "fq.inc" -.include "shuffle.inc" +#include "fq.inc" +#include "shuffle.inc" -/* -nttpack_avx: +.global MLKEM_ASM_NAMESPACE(nttpack_avx2) +MLKEM_ASM_NAMESPACE(nttpack_avx2): #load vmovdqa (%rdi),%ymm4 vmovdqa 32(%rdi),%ymm5 @@ -45,10 +57,8 @@ vmovdqa %ymm5,192(%rdi) vmovdqa %ymm11,224(%rdi) ret -*/ 
-.text -nttunpack128_avx: +nttunpack128_avx2: #load vmovdqa (%rdi),%ymm4 vmovdqa 32(%rdi),%ymm5 @@ -91,11 +101,11 @@ vmovdqa %ymm11,224(%rdi) ret -.global cdecl(nttunpack_avx) -cdecl(nttunpack_avx): -call nttunpack128_avx +.global MLKEM_ASM_NAMESPACE(nttunpack_avx2) +MLKEM_ASM_NAMESPACE(nttunpack_avx2): +call nttunpack128_avx2 add $256,%rdi -call nttunpack128_avx +call nttunpack128_avx2 ret ntttobytes128_avx: @@ -109,16 +119,6 @@ vmovdqa 160(%rsi),%ymm10 vmovdqa 192(%rsi),%ymm11 vmovdqa 224(%rsi),%ymm12 -#csubq -csubq 5,13 -csubq 6,13 -csubq 7,13 -csubq 8,13 -csubq 9,13 -csubq 10,13 -csubq 11,13 -csubq 12,13 - #bitpack vpsllw $12,%ymm6,%ymm4 vpor %ymm4,%ymm5,%ymm4 @@ -168,10 +168,10 @@ vmovdqu %ymm9,160(%rdi) ret -.global cdecl(ntttobytes_avx) -cdecl(ntttobytes_avx): +.global MLKEM_ASM_NAMESPACE(ntttobytes_avx2) +MLKEM_ASM_NAMESPACE(ntttobytes_avx2): #consts -vmovdqa _16XQ*2(%rdx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XQ*2(%rdx),%ymm0 call ntttobytes128_avx add $256,%rsi add $192,%rdi @@ -244,12 +244,14 @@ vmovdqa %ymm1,224(%rdi) ret -.global cdecl(nttfrombytes_avx) -cdecl(nttfrombytes_avx): +.global MLKEM_ASM_NAMESPACE(nttfrombytes_avx2) +MLKEM_ASM_NAMESPACE(nttfrombytes_avx2): #consts -vmovdqa _16XMASK*2(%rdx),%ymm0 +vmovdqa AVX2_BACKEND_DATA_OFFSET_16XMASK*2(%rdx),%ymm0 call nttfrombytes128_avx add $256,%rdi add $192,%rsi call nttfrombytes128_avx ret + +#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.inc b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.inc similarity index 55% rename from src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.inc rename to src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.inc index 73e9ffe03c..359807bd25 100644 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/shuffle.inc +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/shuffle.inc @@ -1,3 +1,8 @@ +/* + * Copyright (c) 
2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + .macro shuffle8 r0,r1,r2,r3 vperm2i128 $0x20,%ymm\r1,%ymm\r0,%ymm\r2 vperm2i128 $0x31,%ymm\r1,%ymm\r0,%ymm\r3 @@ -8,12 +13,19 @@ vpunpcklqdq %ymm\r1,%ymm\r0,%ymm\r2 vpunpckhqdq %ymm\r1,%ymm\r0,%ymm\r3 .endm +/* Shuffle r0=(a0,b0,c0,d0,...), r1=(a1,b1,c1,d1,...) into */ +/* r2 = (a0,b0,a1,b1,e0,f0,e1,f1,...) */ +/* r3 = (c0,d0,c1,d1,g0,h0,g1,h1,...) */ .macro shuffle2 r0,r1,r2,r3 -#vpsllq $32,%ymm\r1,%ymm\r2 +/* r2=(a1,b1,a1,b1,e1,f1,e1,f1,...) */ vmovsldup %ymm\r1,%ymm\r2 +/* Conditional move */ +/* 0xAA = 0b10101010 */ +/* r2=(a0,b0,a1,b1,e0,f0,e1,f1,...) */ vpblendd $0xAA,%ymm\r2,%ymm\r0,%ymm\r2 +/* r0=(c0,d0,0,0,g0,h0,0,0,...) */ vpsrlq $32,%ymm\r0,%ymm\r0 -#vmovshdup %ymm\r0,%ymm\r0 +/* r3=(c0,d0,c1,d1,g0,h0,g1,h1,...) */ vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3 .endm diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/x86_64_zetas.i b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/x86_64_zetas.i new file mode 100644 index 0000000000..26d582ee53 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/x86_64/src/x86_64_zetas.i @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +/* + * Table of zeta values used in the AVX2 NTTs + * See autogen for details. 
+ */ + +31498, 31498, 31498, 31498, -758, -758, -758, -758, 0, 0, 0, 0, 0, 0, 0, 0, + 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, + 14745, 14745, 14745, 14745, 14745, -359, -359, -359, -359, -359, -359, -359, + -359, -359, -359, -359, -359, -359, -359, -359, -359, 13525, 13525, 13525, + 13525, 13525, 13525, 13525, 13525, -12402, -12402, -12402, -12402, -12402, + -12402, -12402, -12402, 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, + 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, -20907, -20907, -20907, + -20907, 27758, 27758, 27758, 27758, -3799, -3799, -3799, -3799, -15690, + -15690, -15690, -15690, -171, -171, -171, -171, 622, 622, 622, 622, 1577, + 1577, 1577, 1577, 182, 182, 182, 182, -5827, -5827, 17363, 17363, -26360, + -26360, -29057, -29057, 5571, 5571, -1102, -1102, 21438, 21438, -26242, + -26242, 573, 573, -1325, -1325, 264, 264, 383, 383, -829, -829, 1458, 1458, + -1602, -1602, -130, -130, -5689, -6516, 1496, 30967, -23565, 20179, 20710, + 25080, -12796, 26616, 16064, -12442, 9134, -650, -25986, 27837, 1223, 652, + -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, -1618, -1162, + 126, 1469, -335, -11477, -32227, 20494, -27738, 945, -14883, 6182, 32010, + 10631, 29175, -28762, -18486, 17560, -14430, -5276, -1103, 555, -1251, 1550, + 422, 177, -291, 1574, -246, 1159, -777, -602, -1590, -872, 418, -156, 11182, + 13387, -14233, -21655, 13131, -4587, 23092, 5493, -32502, 30317, -18741, + 12639, 20100, 18525, 19529, -12619, 430, 843, 871, 105, 587, -235, -460, + 1653, 778, -147, 1483, 1119, 644, 349, 329, -75, 787, 787, 787, 787, 787, + 787, 787, 787, 787, 787, 787, 787, 787, 787, 787, 787, -1517, -1517, -1517, + -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, + -1517, -1517, 28191, 28191, 28191, 28191, 28191, 28191, 28191, 28191, + -16694, -16694, -16694, -16694, -16694, -16694, -16694, -16694, 287, 287, + 287, 287, 287, 287, 287, 287, 202, 202, 202, 202, 202, 202, 202, 202, 10690, + 
10690, 10690, 10690, 1358, 1358, 1358, 1358, -11202, -11202, -11202, -11202, + 31164, 31164, 31164, 31164, 962, 962, 962, 962, -1202, -1202, -1202, -1202, + -1474, -1474, -1474, -1474, 1468, 1468, 1468, 1468, -28073, -28073, 24313, + 24313, -10532, -10532, 8800, 8800, 18426, 18426, 8859, 8859, 26675, 26675, + -16163, -16163, -681, -681, 1017, 1017, 732, 732, 608, 608, -1542, -1542, + 411, 411, -205, -205, -1571, -1571, 19883, -28250, -15887, -8898, -28309, + 9075, -30199, 18249, 13426, 14017, -29156, -12757, 16832, 4311, -24155, + -17915, -853, -90, -271, 830, 107, -1421, -247, -951, -398, 961, -1508, + -725, 448, -1065, 677, -1275, -31183, 25435, -7382, 24391, -20927, 10946, + 24214, 16989, 10335, -7934, -22502, 10906, 31636, 28644, 23998, -17422, 817, + 603, 1322, -1465, -1215, 1218, -874, -1187, -1185, -1278, -1510, -870, -108, + 996, 958, 1522, 20297, 2146, 15355, -32384, -6280, -14903, -11044, 14469, + -21498, -20198, 23210, -17442, -23860, -20257, 7756, 23132, 1097, 610, + -1285, 384, -136, -1335, 220, -1659, -1530, 794, -854, 478, -308, 991, + -1460, 1628, diff --git a/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/zetas.c b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/zetas.c new file mode 100644 index 0000000000..1a26e0dd59 --- /dev/null +++ b/src/kem/ml_kem/mlkem-native_ml-kem-768_x86_64/zetas.c @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2024 The mlkem-native project authors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * WARNING: This file is auto-generated from scripts/autogen + * Do not modify it directly. + */ + +#include "ntt.h" + +/* + * Table of zeta values used in the reference NTT and inverse NTT. + * See autogen for details. 
+ */ +ALIGN const int16_t zetas[128] = { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628, +}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/align.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/align.h deleted file mode 100644 index 3463866f37..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/align.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef ALIGN_H -#define ALIGN_H - -#include -#include - -#define ALIGNED_UINT8(N) \ - union { \ - uint8_t coeffs[N]; \ - __m256i vec[(N+31)/32]; \ - } - -#define ALIGNED_INT16(N) \ - union { \ - int16_t coeffs[N]; \ - __m256i vec[(N+15)/16]; \ - } - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/api.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/api.h deleted file mode 100644 index a154e80f1d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/api.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef API_H -#define API_H - -#include - -#define pqcrystals_kyber512_SECRETKEYBYTES 1632 -#define pqcrystals_kyber512_PUBLICKEYBYTES 800 -#define pqcrystals_kyber512_CIPHERTEXTBYTES 768 -#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber512_ENCCOINBYTES 32 -#define 
pqcrystals_kyber512_BYTES 32 - -#define pqcrystals_kyber512_avx2_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES -#define pqcrystals_kyber512_avx2_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES -#define pqcrystals_kyber512_avx2_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES -#define pqcrystals_kyber512_avx2_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES -#define pqcrystals_kyber512_avx2_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES -#define pqcrystals_kyber512_avx2_BYTES pqcrystals_kyber512_BYTES - -int pqcrystals_kyber512_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber512_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber512_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber512_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber512_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber768_SECRETKEYBYTES 2400 -#define pqcrystals_kyber768_PUBLICKEYBYTES 1184 -#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088 -#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber768_ENCCOINBYTES 32 -#define pqcrystals_kyber768_BYTES 32 - -#define pqcrystals_kyber768_avx2_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES -#define pqcrystals_kyber768_avx2_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES -#define pqcrystals_kyber768_avx2_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES -#define pqcrystals_kyber768_avx2_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES -#define pqcrystals_kyber768_avx2_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES -#define pqcrystals_kyber768_avx2_BYTES pqcrystals_kyber768_BYTES - -int pqcrystals_kyber768_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber768_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber768_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); 
-int pqcrystals_kyber768_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber768_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber1024_SECRETKEYBYTES 3168 -#define pqcrystals_kyber1024_PUBLICKEYBYTES 1568 -#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568 -#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber1024_ENCCOINBYTES 32 -#define pqcrystals_kyber1024_BYTES 32 - -#define pqcrystals_kyber1024_avx2_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES -#define pqcrystals_kyber1024_avx2_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES -#define pqcrystals_kyber1024_avx2_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES -#define pqcrystals_kyber1024_avx2_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES -#define pqcrystals_kyber1024_avx2_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES -#define pqcrystals_kyber1024_avx2_BYTES pqcrystals_kyber1024_BYTES - -int pqcrystals_kyber1024_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber1024_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber1024_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber1024_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber1024_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.c deleted file mode 100644 index dad473c79e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.c +++ /dev/null @@ -1,144 +0,0 @@ -#include -#include -#include "params.h" -#include "cbd.h" - -/************************************************* -* Name: cbd2 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with 
parameter eta=2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const __m256i *buf: pointer to aligned input byte array -**************************************************/ -static void cbd2(poly * restrict r, const __m256i buf[2*KYBER_N/128]) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i mask55 = _mm256_set1_epi32(0x55555555); - const __m256i mask33 = _mm256_set1_epi32(0x33333333); - const __m256i mask03 = _mm256_set1_epi32(0x03030303); - const __m256i mask0F = _mm256_set1_epi32(0x0F0F0F0F); - - for(i = 0; i < KYBER_N/64; i++) { - f0 = _mm256_load_si256(&buf[i]); - - f1 = _mm256_srli_epi16(f0, 1); - f0 = _mm256_and_si256(mask55, f0); - f1 = _mm256_and_si256(mask55, f1); - f0 = _mm256_add_epi8(f0, f1); - - f1 = _mm256_srli_epi16(f0, 2); - f0 = _mm256_and_si256(mask33, f0); - f1 = _mm256_and_si256(mask33, f1); - f0 = _mm256_add_epi8(f0, mask33); - f0 = _mm256_sub_epi8(f0, f1); - - f1 = _mm256_srli_epi16(f0, 4); - f0 = _mm256_and_si256(mask0F, f0); - f1 = _mm256_and_si256(mask0F, f1); - f0 = _mm256_sub_epi8(f0, mask03); - f1 = _mm256_sub_epi8(f1, mask03); - - f2 = _mm256_unpacklo_epi8(f0, f1); - f3 = _mm256_unpackhi_epi8(f0, f1); - - f0 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(f2)); - f1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(f2,1)); - f2 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(f3)); - f3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(f3,1)); - - _mm256_store_si256(&r->vec[4*i+0], f0); - _mm256_store_si256(&r->vec[4*i+1], f2); - _mm256_store_si256(&r->vec[4*i+2], f1); - _mm256_store_si256(&r->vec[4*i+3], f3); - } -} - -#if KYBER_ETA1 == 3 -/************************************************* -* Name: cbd3 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=3 -* This function is only needed for Kyber-512 -* -* Arguments: - poly *r: pointer to output polynomial -* - const __m256i *buf: pointer 
to aligned input byte array -**************************************************/ -static void cbd3(poly * restrict r, const uint8_t buf[3*KYBER_N/4+8]) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i mask249 = _mm256_set1_epi32(0x249249); - const __m256i mask6DB = _mm256_set1_epi32(0x6DB6DB); - const __m256i mask07 = _mm256_set1_epi32(7); - const __m256i mask70 = _mm256_set1_epi32(7 << 16); - const __m256i mask3 = _mm256_set1_epi16(3); - const __m256i shufbidx = _mm256_set_epi8(-1,15,14,13,-1,12,11,10,-1, 9, 8, 7,-1, 6, 5, 4, - -1,11,10, 9,-1, 8, 7, 6,-1, 5, 4, 3,-1, 2, 1, 0); - - for(i = 0; i < KYBER_N/32; i++) { - f0 = _mm256_loadu_si256((__m256i *)&buf[24*i]); - f0 = _mm256_permute4x64_epi64(f0,0x94); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - - f1 = _mm256_srli_epi32(f0,1); - f2 = _mm256_srli_epi32(f0,2); - f0 = _mm256_and_si256(mask249,f0); - f1 = _mm256_and_si256(mask249,f1); - f2 = _mm256_and_si256(mask249,f2); - f0 = _mm256_add_epi32(f0,f1); - f0 = _mm256_add_epi32(f0,f2); - - f1 = _mm256_srli_epi32(f0,3); - f0 = _mm256_add_epi32(f0,mask6DB); - f0 = _mm256_sub_epi32(f0,f1); - - f1 = _mm256_slli_epi32(f0,10); - f2 = _mm256_srli_epi32(f0,12); - f3 = _mm256_srli_epi32(f0, 2); - f0 = _mm256_and_si256(f0,mask07); - f1 = _mm256_and_si256(f1,mask70); - f2 = _mm256_and_si256(f2,mask07); - f3 = _mm256_and_si256(f3,mask70); - f0 = _mm256_add_epi16(f0,f1); - f1 = _mm256_add_epi16(f2,f3); - f0 = _mm256_sub_epi16(f0,mask3); - f1 = _mm256_sub_epi16(f1,mask3); - - f2 = _mm256_unpacklo_epi32(f0,f1); - f3 = _mm256_unpackhi_epi32(f0,f1); - - f0 = _mm256_permute2x128_si256(f2,f3,0x20); - f1 = _mm256_permute2x128_si256(f2,f3,0x31); - - _mm256_store_si256(&r->vec[2*i+0], f0); - _mm256_store_si256(&r->vec[2*i+1], f1); - } -} -#endif - -/* buf 32 bytes longer for cbd3 */ -void poly_cbd_eta1(poly *r, const __m256i buf[KYBER_ETA1*KYBER_N/128+1]) -{ -#if KYBER_ETA1 == 2 - cbd2(r, buf); -#elif KYBER_ETA1 == 3 - cbd3(r, (uint8_t *)buf); -#else -#error "This 
implementation requires eta1 in {2,3}" -#endif -} - -void poly_cbd_eta2(poly *r, const __m256i buf[KYBER_ETA2*KYBER_N/128]) -{ -#if KYBER_ETA2 == 2 - cbd2(r, buf); -#else -#error "This implementation requires eta2 = 2" -#endif -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.h deleted file mode 100644 index 05788e06b4..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/cbd.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef CBD_H -#define CBD_H - -#include -#include -#include "params.h" -#include "poly.h" - -#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1) -void poly_cbd_eta1(poly *r, const __m256i buf[KYBER_ETA1*KYBER_N/128+1]); - -#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2) -void poly_cbd_eta2(poly *r, const __m256i buf[KYBER_ETA2*KYBER_N/128]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.c deleted file mode 100644 index 84e596893d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.c +++ /dev/null @@ -1,121 +0,0 @@ -#include "align.h" -#include "params.h" -#include "consts.h" - -#define Q KYBER_Q -#define MONT -1044 // 2^16 mod q -#define QINV -3327 // q^-1 mod 2^16 -#define V 20159 // floor(2^26/q + 0.5) -#define FHI 1441 // mont^2/128 -#define FLO -10079 // qinv*FHI -#define MONTSQHI 1353 // mont^2 -#define MONTSQLO 20553 // qinv*MONTSQHI -#define MASK 4095 -#define SHIFT 32 - -const qdata_t qdata = {{ -#define _16XQ 0 - Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, - -#define _16XQINV 16 - QINV, QINV, QINV, QINV, QINV, QINV, QINV, QINV, - QINV, QINV, QINV, QINV, QINV, QINV, QINV, QINV, - -#define _16XV 32 - V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, - -#define _16XFLO 48 - FLO, FLO, FLO, FLO, FLO, FLO, FLO, FLO, - FLO, FLO, FLO, FLO, FLO, FLO, FLO, FLO, - -#define _16XFHI 64 - FHI, FHI, FHI, 
FHI, FHI, FHI, FHI, FHI, - FHI, FHI, FHI, FHI, FHI, FHI, FHI, FHI, - -#define _16XMONTSQLO 80 - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - -#define _16XMONTSQHI 96 - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - -#define _16XMASK 112 - MASK, MASK, MASK, MASK, MASK, MASK, MASK, MASK, - MASK, MASK, MASK, MASK, MASK, MASK, MASK, MASK, - -#define _REVIDXB 128 - 3854, 3340, 2826, 2312, 1798, 1284, 770, 256, - 3854, 3340, 2826, 2312, 1798, 1284, 770, 256, - -#define _REVIDXD 144 - 7, 0, 6, 0, 5, 0, 4, 0, 3, 0, 2, 0, 1, 0, 0, 0, - -#define _ZETAS_EXP 160 - 31498, 31498, 31498, 31498, -758, -758, -758, -758, - 5237, 5237, 5237, 5237, 1397, 1397, 1397, 1397, - 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, - 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - 13525, 13525, 13525, 13525, 13525, 13525, 13525, 13525, - -12402, -12402, -12402, -12402, -12402, -12402, -12402, -12402, - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - -20907, -20907, -20907, -20907, 27758, 27758, 27758, 27758, - -3799, -3799, -3799, -3799, -15690, -15690, -15690, -15690, - -171, -171, -171, -171, 622, 622, 622, 622, - 1577, 1577, 1577, 1577, 182, 182, 182, 182, - -5827, -5827, 17363, 17363, -26360, -26360, -29057, -29057, - 5571, 5571, -1102, -1102, 21438, 21438, -26242, -26242, - 573, 573, -1325, -1325, 264, 264, 383, 383, - -829, -829, 1458, 1458, -1602, -1602, -130, -130, - -5689, -6516, 1496, 30967, -23565, 20179, 20710, 25080, - -12796, 26616, 16064, -12442, 9134, -650, -25986, 27837, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, -1618, -1162, 126, 1469, - 
-335, -11477, -32227, 20494, -27738, 945, -14883, 6182, - 32010, 10631, 29175, -28762, -18486, 17560, -14430, -5276, - -1103, 555, -1251, 1550, 422, 177, -291, 1574, - -246, 1159, -777, -602, -1590, -872, 418, -156, - 11182, 13387, -14233, -21655, 13131, -4587, 23092, 5493, - -32502, 30317, -18741, 12639, 20100, 18525, 19529, -12619, - 430, 843, 871, 105, 587, -235, -460, 1653, - 778, -147, 1483, 1119, 644, 349, 329, -75, - 787, 787, 787, 787, 787, 787, 787, 787, - 787, 787, 787, 787, 787, 787, 787, 787, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - 28191, 28191, 28191, 28191, 28191, 28191, 28191, 28191, - -16694, -16694, -16694, -16694, -16694, -16694, -16694, -16694, - 287, 287, 287, 287, 287, 287, 287, 287, - 202, 202, 202, 202, 202, 202, 202, 202, - 10690, 10690, 10690, 10690, 1358, 1358, 1358, 1358, - -11202, -11202, -11202, -11202, 31164, 31164, 31164, 31164, - 962, 962, 962, 962, -1202, -1202, -1202, -1202, - -1474, -1474, -1474, -1474, 1468, 1468, 1468, 1468, - -28073, -28073, 24313, 24313, -10532, -10532, 8800, 8800, - 18426, 18426, 8859, 8859, 26675, 26675, -16163, -16163, - -681, -681, 1017, 1017, 732, 732, 608, 608, - -1542, -1542, 411, 411, -205, -205, -1571, -1571, - 19883, -28250, -15887, -8898, -28309, 9075, -30199, 18249, - 13426, 14017, -29156, -12757, 16832, 4311, -24155, -17915, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -31183, 25435, -7382, 24391, -20927, 10946, 24214, 16989, - 10335, -7934, -22502, 10906, 31636, 28644, 23998, -17422, - 817, 603, 1322, -1465, -1215, 1218, -874, -1187, - -1185, -1278, -1510, -870, -108, 996, 958, 1522, - 20297, 2146, 15355, -32384, -6280, -14903, -11044, 14469, - -21498, -20198, 23210, -17442, -23860, -20257, 7756, 23132, - 1097, 610, -1285, 384, -136, -1335, 220, -1659, - -1530, 794, -854, 478, -308, 991, -1460, 1628, - -#define _16XSHIFT 624 - SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, 
SHIFT, SHIFT, SHIFT, - SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT -}}; diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.h deleted file mode 100644 index f95899cd8e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/consts.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef CONSTS_H -#define CONSTS_H - -#include "params.h" - -#define _16XQ 0 -#define _16XQINV 16 -#define _16XV 32 -#define _16XFLO 48 -#define _16XFHI 64 -#define _16XMONTSQLO 80 -#define _16XMONTSQHI 96 -#define _16XMASK 112 -#define _REVIDXB 128 -#define _REVIDXD 144 -#define _ZETAS_EXP 160 -#define _16XSHIFT 624 - -/* The C ABI on MacOS exports all symbols with a leading - * underscore. This means that any symbols we refer to from - * C files (functions) can't be found, and all symbols we - * refer to from ASM also can't be found. - * - * This define helps us get around this - */ -#ifdef __ASSEMBLER__ -#if defined(__WIN32__) || defined(__APPLE__) -#define decorate(s) _##s -#define cdecl2(s) decorate(s) -#define cdecl(s) cdecl2(KYBER_NAMESPACE(##s)) -#else -#define cdecl(s) KYBER_NAMESPACE(##s) -#endif -#endif - -#ifndef __ASSEMBLER__ -#include "align.h" -typedef ALIGNED_INT16(640) qdata_t; -#define qdata KYBER_NAMESPACE(qdata) -extern const qdata_t qdata; -#endif - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/indcpa.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/indcpa.c deleted file mode 100644 index c4b2b3a89f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/indcpa.c +++ /dev/null @@ -1,568 +0,0 @@ -#include -#include -#include -#include -#include "align.h" -#include "params.h" -#include "indcpa.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" -#include "cbd.h" -#include "rejsample.h" -#include "symmetric.h" -#include "randombytes.h" - -/************************************************* 
-* Name: pack_pk -* -* Description: Serialize the public key as concatenation of the -* serialized vector of polynomials pk and the -* public seed used to generate the matrix A. -* The polynomial coefficients in pk are assumed to -* lie in the invertal [0,q], i.e. pk must be reduced -* by polyvec_reduce(). -* -* Arguments: uint8_t *r: pointer to the output serialized public key -* polyvec *pk: pointer to the input public-key polyvec -* const uint8_t *seed: pointer to the input public seed -**************************************************/ -static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], - polyvec *pk, - const uint8_t seed[KYBER_SYMBYTES]) -{ - polyvec_tobytes(r, pk); - memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES); -} - -/************************************************* -* Name: unpack_pk -* -* Description: De-serialize public key from a byte array; -* approximate inverse of pack_pk -* -* Arguments: - polyvec *pk: pointer to output public-key polynomial vector -* - uint8_t *seed: pointer to output seed to generate matrix A -* - const uint8_t *packedpk: pointer to input serialized public key -**************************************************/ -static void unpack_pk(polyvec *pk, - uint8_t seed[KYBER_SYMBYTES], - const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) -{ - polyvec_frombytes(pk, packedpk); - memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES); -} - -/************************************************* -* Name: pack_sk -* -* Description: Serialize the secret key. -* The polynomial coefficients in sk are assumed to -* lie in the invertal [0,q], i.e. sk must be reduced -* by polyvec_reduce(). 
-* -* Arguments: - uint8_t *r: pointer to output serialized secret key -* - polyvec *sk: pointer to input vector of polynomials (secret key) -**************************************************/ -static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) -{ - polyvec_tobytes(r, sk); -} - -/************************************************* -* Name: unpack_sk -* -* Description: De-serialize the secret key; inverse of pack_sk -* -* Arguments: - polyvec *sk: pointer to output vector of polynomials (secret key) -* - const uint8_t *packedsk: pointer to input serialized secret key -**************************************************/ -static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) -{ - polyvec_frombytes(sk, packedsk); -} - -/************************************************* -* Name: pack_ciphertext -* -* Description: Serialize the ciphertext as concatenation of the -* compressed and serialized vector of polynomials b -* and the compressed and serialized polynomial v. -* The polynomial coefficients in b and v are assumed to -* lie in the invertal [0,q], i.e. b and v must be reduced -* by polyvec_reduce() and poly_reduce(), respectively. 
-* -* Arguments: uint8_t *r: pointer to the output serialized ciphertext -* poly *pk: pointer to the input vector of polynomials b -* poly *v: pointer to the input polynomial v -**************************************************/ -static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v) -{ - polyvec_compress(r, b); - poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v); -} - -/************************************************* -* Name: unpack_ciphertext -* -* Description: De-serialize and decompress ciphertext from a byte array; -* approximate inverse of pack_ciphertext -* -* Arguments: - polyvec *b: pointer to the output vector of polynomials b -* - poly *v: pointer to the output polynomial v -* - const uint8_t *c: pointer to the input serialized ciphertext -**************************************************/ -static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES]) -{ - polyvec_decompress(b, c); - poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES); -} - -/************************************************* -* Name: rej_uniform -* -* Description: Run rejection sampling on uniform random bytes to generate -* uniform random integers mod q -* -* Arguments: - int16_t *r: pointer to output array -* - unsigned int len: requested number of 16-bit integers (uniform mod q) -* - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes) -* - unsigned int buflen: length of input buffer in bytes -* -* Returns number of sampled 16-bit integers (at most len) -**************************************************/ -static unsigned int rej_uniform(int16_t *r, - unsigned int len, - const uint8_t *buf, - unsigned int buflen) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - - ctr = pos = 0; - while(ctr < len && pos <= buflen - 3) { // buflen is always at least 3 - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF; - pos += 3; - - 
if(val0 < KYBER_Q) - r[ctr++] = val0; - if(ctr < len && val1 < KYBER_Q) - r[ctr++] = val1; - } - - return ctr; -} - -#define gen_a(A,B) gen_matrix(A,B,0) -#define gen_at(A,B) gen_matrix(A,B,1) - -/************************************************* -* Name: gen_matrix -* -* Description: Deterministically generate matrix A (or the transpose of A) -* from a seed. Entries of the matrix are polynomials that look -* uniformly random. Performs rejection sampling on output of -* a XOF -* -* Arguments: - polyvec *a: pointer to ouptput matrix A -* - const uint8_t *seed: pointer to input seed -* - int transposed: boolean deciding whether A or A^T is generated -**************************************************/ -#if KYBER_K == 2 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 0; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 1; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 1; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 0; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 1; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 1; - } - - shake128x4_inc_init(&state); - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[1].vec[0].coeffs, 
buf[2].coeffs); - ctr3 = rej_uniform_avx(a[1].vec[1].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[1].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[1].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[0].vec[0]); - poly_nttunpack(&a[0].vec[1]); - poly_nttunpack(&a[1].vec[0]); - poly_nttunpack(&a[1].vec[1]); - shake128x4_inc_ctx_release(&state); -} -#elif KYBER_K == 3 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - shake128incctx state1x; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 0; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 0; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 0; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 0; - buf[3].coeffs[33] = 1; - } - - shake128x4_inc_init(&state); - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = 
rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[0].vec[2].coeffs, buf[2].coeffs); - ctr3 = rej_uniform_avx(a[1].vec[0].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[0].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[1].vec[0].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[0].vec[0]); - poly_nttunpack(&a[0].vec[1]); - poly_nttunpack(&a[0].vec[2]); - poly_nttunpack(&a[1].vec[0]); - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 1; - buf[0].coeffs[33] = 1; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 2; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 2; - buf[3].coeffs[33] = 1; - } - else { - buf[0].coeffs[32] = 1; - buf[0].coeffs[33] = 1; - buf[1].coeffs[32] = 2; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 2; - } - - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[1].vec[1].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[1].vec[2].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[2].vec[0].coeffs, buf[2].coeffs); - ctr3 = 
rej_uniform_avx(a[2].vec[1].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[1].vec[1].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[1].vec[2].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[2].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[2].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - shake128x4_inc_ctx_release(&state); - - poly_nttunpack(&a[1].vec[1]); - poly_nttunpack(&a[1].vec[2]); - poly_nttunpack(&a[2].vec[0]); - poly_nttunpack(&a[2].vec[1]); - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - buf[0].coeffs[32] = 2; - buf[0].coeffs[33] = 2; - - shake128_inc_init(&state1x); - shake128_absorb_once(&state1x, buf[0].coeffs, 34); - shake128_squeezeblocks(buf[0].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state1x); - ctr0 = rej_uniform_avx(a[2].vec[2].coeffs, buf[0].coeffs); - while(ctr0 < KYBER_N) { - shake128_squeezeblocks(buf[0].coeffs, 1, &state1x); - ctr0 += rej_uniform(a[2].vec[2].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - } - shake128_inc_ctx_release(&state1x); - - poly_nttunpack(&a[2].vec[2]); -} -#elif KYBER_K == 4 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int i, ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - shake128x4_inc_init(&state); - - for(i=0;i<4;i++) { - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = i; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = i; - buf[1].coeffs[33] = 
1; - buf[2].coeffs[32] = i; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = i; - buf[3].coeffs[33] = 3; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = i; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = i; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = i; - buf[3].coeffs[32] = 3; - buf[3].coeffs[33] = i; - } - - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[i].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[i].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[i].vec[2].coeffs, buf[2].coeffs); - ctr3 = rej_uniform_avx(a[i].vec[3].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[i].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[i].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[i].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[i].vec[3].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[i].vec[0]); - poly_nttunpack(&a[i].vec[1]); - poly_nttunpack(&a[i].vec[2]); - poly_nttunpack(&a[i].vec[3]); - } - shake128x4_inc_ctx_release(&state); -} -#endif - -/************************************************* -* Name: indcpa_keypair_derand -* -* Description: Generates public and private key for the CPA-secure -* public-key encryption scheme underlying Kyber -* -* Arguments: - uint8_t *pk: pointer to output public key -* (of length KYBER_INDCPA_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (of length KYBER_INDCPA_SECRETKEYBYTES bytes) -* - const uint8_t *coins: pointer to input 
randomness -* (of length KYBER_SYMBYTES bytes) -**************************************************/ -void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]) -{ - unsigned int i; - uint8_t buf[2*KYBER_SYMBYTES]; - const uint8_t *publicseed = buf; - const uint8_t *noiseseed = buf + KYBER_SYMBYTES; - polyvec a[KYBER_K], e, pkpv, skpv; - - memcpy(buf, coins, KYBER_SYMBYTES); - buf[KYBER_SYMBYTES] = KYBER_K; - hash_g(buf, buf, KYBER_SYMBYTES+1); - - gen_a(a, publicseed); - -#if KYBER_K == 2 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, e.vec+0, e.vec+1, noiseseed, 0, 1, 2, 3); -#elif KYBER_K == 3 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, e.vec+0, noiseseed, 0, 1, 2, 3); - poly_getnoise_eta1_4x(e.vec+1, e.vec+2, pkpv.vec+0, pkpv.vec+1, noiseseed, 4, 5, 6, 7); -#elif KYBER_K == 4 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, skpv.vec+3, noiseseed, 0, 1, 2, 3); - poly_getnoise_eta1_4x(e.vec+0, e.vec+1, e.vec+2, e.vec+3, noiseseed, 4, 5, 6, 7); -#endif - - polyvec_ntt(&skpv); - polyvec_reduce(&skpv); - polyvec_ntt(&e); - - // matrix-vector multiplication - for(i=0;i -#include "params.h" -#include "polyvec.h" - -#define gen_matrix KYBER_NAMESPACE(gen_matrix) -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed); - -#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand) -void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_enc KYBER_NAMESPACE(indcpa_enc) -void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_dec KYBER_NAMESPACE(indcpa_dec) -void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t 
sk[KYBER_INDCPA_SECRETKEYBYTES]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/invntt.S b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/invntt.S deleted file mode 100644 index 76d4189996..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/invntt.S +++ /dev/null @@ -1,193 +0,0 @@ -#include "consts.h" -.include "shuffle.inc" -.include "fq.inc" - -.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3 -vpsubw %ymm\rl0,%ymm\rh0,%ymm12 -vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 -vpsubw %ymm\rl1,%ymm\rh1,%ymm13 - -vpmullw %ymm\zl0,%ymm12,%ymm\rh0 -vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 -vpsubw %ymm\rl2,%ymm\rh2,%ymm14 - -vpmullw %ymm\zl0,%ymm13,%ymm\rh1 -vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 -vpsubw %ymm\rl3,%ymm\rh3,%ymm15 - -vpmullw %ymm\zl1,%ymm14,%ymm\rh2 -vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 -vpmullw %ymm\zl1,%ymm15,%ymm\rh3 - -vpmulhw %ymm\zh0,%ymm12,%ymm12 -vpmulhw %ymm\zh0,%ymm13,%ymm13 - -vpmulhw %ymm\zh1,%ymm14,%ymm14 -vpmulhw %ymm\zh1,%ymm15,%ymm15 - -vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 - -vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 - -vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 -vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 - -# - -# - -vpsubw %ymm\rh0,%ymm12,%ymm\rh0 - -vpsubw %ymm\rh1,%ymm13,%ymm\rh1 - -vpsubw %ymm\rh2,%ymm14,%ymm\rh2 -vpsubw %ymm\rh3,%ymm15,%ymm\rh3 -.endm - -.macro intt_levels0t5 off -/* level 0 */ -vmovdqa _16XFLO*2(%rsi),%ymm2 -vmovdqa _16XFHI*2(%rsi),%ymm3 - -vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 - -fqmulprecomp 2,3,4 -fqmulprecomp 2,3,6 -fqmulprecomp 2,3,5 -fqmulprecomp 2,3,7 - -vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 -vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 -vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 -vmovdqa (128*\off+112)*2(%rdi),%ymm11 - -fqmulprecomp 2,3,8 -fqmulprecomp 2,3,10 -fqmulprecomp 2,3,9 -fqmulprecomp 2,3,11 - -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15 -vpermq 
$0x4E,(_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+192)*2(%rsi),%ymm3 -vmovdqa _REVIDXB*2(%rsi),%ymm12 -vpshufb %ymm12,%ymm15,%ymm15 -vpshufb %ymm12,%ymm1,%ymm1 -vpshufb %ymm12,%ymm2,%ymm2 -vpshufb %ymm12,%ymm3,%ymm3 - -butterfly 4,5,8,9,6,7,10,11,15,1,2,3 - -/* level 1 */ -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3 -vmovdqa _REVIDXB*2(%rsi),%ymm1 -vpshufb %ymm1,%ymm2,%ymm2 -vpshufb %ymm1,%ymm3,%ymm3 - -butterfly 4,5,6,7,8,9,10,11,2,2,3,3 - -shuffle1 4,5,3,5 -shuffle1 6,7,4,7 -shuffle1 8,9,6,9 -shuffle1 10,11,8,11 - -/* level 2 */ -vmovdqa _REVIDXD*2(%rsi),%ymm12 -vpermd (_ZETAS_EXP+(1-\off)*224+112)*2(%rsi),%ymm12,%ymm2 -vpermd (_ZETAS_EXP+(1-\off)*224+128)*2(%rsi),%ymm12,%ymm10 - -butterfly 3,4,6,8,5,7,9,11,2,2,10,10 - -vmovdqa _16XV*2(%rsi),%ymm1 -red16 3 - -shuffle2 3,4,10,4 -shuffle2 6,8,3,8 -shuffle2 5,7,6,7 -shuffle2 9,11,5,11 - -/* level 3 */ -vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+80)*2(%rsi),%ymm2 -vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+96)*2(%rsi),%ymm9 - -butterfly 10,3,6,5,4,8,7,11,2,2,9,9 - -shuffle4 10,3,9,3 -shuffle4 6,5,10,5 -shuffle4 4,8,6,8 -shuffle4 7,11,4,11 - -/* level 4 */ -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+48)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+64)*2(%rsi),%ymm7 - -butterfly 9,10,6,4,3,5,8,11,2,2,7,7 - -red16 9 - -shuffle8 9,10,7,10 -shuffle8 6,4,9,4 -shuffle8 3,5,6,5 -shuffle8 8,11,3,11 - -/* level 5 */ -vmovdqa (_ZETAS_EXP+(1-\off)*224+16)*2(%rsi),%ymm2 -vmovdqa (_ZETAS_EXP+(1-\off)*224+32)*2(%rsi),%ymm8 - -butterfly 7,9,6,3,10,4,5,11,2,2,8,8 - -vmovdqa %ymm7,(128*\off+ 0)*2(%rdi) -vmovdqa %ymm9,(128*\off+ 16)*2(%rdi) -vmovdqa %ymm6,(128*\off+ 32)*2(%rdi) -vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) -vmovdqa %ymm10,(128*\off+ 64)*2(%rdi) -vmovdqa %ymm4,(128*\off+ 80)*2(%rdi) -vmovdqa %ymm5,(128*\off+ 96)*2(%rdi) -vmovdqa %ymm11,(128*\off+112)*2(%rdi) -.endm - 
-.macro intt_level6 off -/* level 6 */ -vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (64*\off+128)*2(%rdi),%ymm8 -vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (64*\off+144)*2(%rdi),%ymm9 -vpbroadcastq (_ZETAS_EXP+0)*2(%rsi),%ymm2 - -vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (64*\off+160)*2(%rdi),%ymm10 -vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 -vmovdqa (64*\off+176)*2(%rdi),%ymm11 -vpbroadcastq (_ZETAS_EXP+4)*2(%rsi),%ymm3 - -butterfly 4,5,6,7,8,9,10,11 - -.if \off == 0 -red16 4 -.endif - -vmovdqa %ymm4,(64*\off+ 0)*2(%rdi) -vmovdqa %ymm5,(64*\off+ 16)*2(%rdi) -vmovdqa %ymm6,(64*\off+ 32)*2(%rdi) -vmovdqa %ymm7,(64*\off+ 48)*2(%rdi) -vmovdqa %ymm8,(64*\off+128)*2(%rdi) -vmovdqa %ymm9,(64*\off+144)*2(%rdi) -vmovdqa %ymm10,(64*\off+160)*2(%rdi) -vmovdqa %ymm11,(64*\off+176)*2(%rdi) -.endm - -.text -.global cdecl(invntt_avx) -cdecl(invntt_avx): -vmovdqa _16XQ*2(%rsi),%ymm0 - -intt_levels0t5 0 -intt_levels0t5 1 - -intt_level6 0 -intt_level6 1 -ret diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.c deleted file mode 100644 index 63abc1029c..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.c +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "kem.h" -#include "indcpa.h" -#include "verify.h" -#include "symmetric.h" -#include "randombytes.h" -/************************************************* -* Name: crypto_kem_keypair_derand -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* - uint8_t *coins: pointer to input randomness -* (an already allocated array filled with 2*KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) 
-**************************************************/ -int crypto_kem_keypair_derand(uint8_t *pk, - uint8_t *sk, - const uint8_t *coins) -{ - indcpa_keypair_derand(pk, sk, coins); - memcpy(sk+KYBER_INDCPA_SECRETKEYBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_h(sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - /* Value z for pseudo-random output on reject */ - memcpy(sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES, coins+KYBER_SYMBYTES, KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_keypair -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair(uint8_t *pk, - uint8_t *sk) -{ - uint8_t coins[2*KYBER_SYMBYTES]; - randombytes(coins, 2*KYBER_SYMBYTES); - crypto_kem_keypair_derand(pk, sk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc_derand -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - const uint8_t *coins: pointer to input randomness -* (an already allocated array filled with KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc_derand(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk, - const uint8_t *coins) -{ - uint8_t 
buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - - memcpy(buf, coins, KYBER_SYMBYTES); - - /* Multitarget countermeasure for coins + contributory KEM */ - hash_h(buf+KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(ct, buf, pk, kr+KYBER_SYMBYTES); - - memcpy(ss,kr,KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk) -{ - uint8_t coins[KYBER_SYMBYTES]; - randombytes(coins, KYBER_SYMBYTES); - crypto_kem_enc_derand(ct, ss, pk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_dec -* -* Description: Generates shared secret for given -* cipher text and private key -* -* Arguments: - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *ct: pointer to input cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - const uint8_t *sk: pointer to input private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0. -* -* On failure, ss will contain a pseudo-random value. 
-**************************************************/ -int crypto_kem_dec(uint8_t *ss, - const uint8_t *ct, - const uint8_t *sk) -{ - int fail; - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - uint8_t cmp[KYBER_CIPHERTEXTBYTES+KYBER_SYMBYTES]; - const uint8_t *pk = sk+KYBER_INDCPA_SECRETKEYBYTES; - - indcpa_dec(buf, ct, sk); - - /* Multitarget countermeasure for coins + contributory KEM */ - memcpy(buf+KYBER_SYMBYTES, sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, KYBER_SYMBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(cmp, buf, pk, kr+KYBER_SYMBYTES); - - fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES); - - /* Compute rejection key */ - rkprf(ss,sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES,ct); - - /* Copy true key to return buffer if fail is false */ - cmov(ss,kr,KYBER_SYMBYTES,!fail); - - return 0; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.h deleted file mode 100644 index 234f11966b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/kem.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef KEM_H -#define KEM_H - -#include -#include "params.h" - -#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES -#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES -#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES -#define CRYPTO_BYTES KYBER_SSBYTES - -#if (KYBER_K == 2) -#define CRYPTO_ALGNAME "Kyber512" -#elif (KYBER_K == 3) -#define CRYPTO_ALGNAME "Kyber768" -#elif (KYBER_K == 4) -#define CRYPTO_ALGNAME "Kyber1024" -#endif - -#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand) -int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); - -#define crypto_kem_keypair KYBER_NAMESPACE(keypair) -int crypto_kem_keypair(uint8_t *pk, uint8_t *sk); - -#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand) -int crypto_kem_enc_derand(uint8_t *ct, 
uint8_t *ss, const uint8_t *pk, const uint8_t *coins); - -#define crypto_kem_enc KYBER_NAMESPACE(enc) -int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); - -#define crypto_kem_dec KYBER_NAMESPACE(dec) -int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.S b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.S deleted file mode 100644 index 0ce7b41297..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.S +++ /dev/null @@ -1,189 +0,0 @@ -#include "consts.h" -.include "shuffle.inc" - -.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2 -vpmullw %ymm\zl0,%ymm\rh0,%ymm12 -vpmullw %ymm\zl0,%ymm\rh1,%ymm13 - -vpmullw %ymm\zl1,%ymm\rh2,%ymm14 -vpmullw %ymm\zl1,%ymm\rh3,%ymm15 - -vpmulhw %ymm\zh0,%ymm\rh0,%ymm\rh0 -vpmulhw %ymm\zh0,%ymm\rh1,%ymm\rh1 - -vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2 -vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3 -.endm - -.macro reduce -vpmulhw %ymm0,%ymm12,%ymm12 -vpmulhw %ymm0,%ymm13,%ymm13 - -vpmulhw %ymm0,%ymm14,%ymm14 -vpmulhw %ymm0,%ymm15,%ymm15 -.endm - -.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3 -vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln -vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 -vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 - -vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 -vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 -vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 - -vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 -vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 - -vpsubw %ymm12,%ymm\rln,%ymm\rln -vpaddw %ymm12,%ymm\rh0,%ymm\rh0 -vpsubw %ymm13,%ymm\rl0,%ymm\rl0 - -vpaddw %ymm13,%ymm\rh1,%ymm\rh1 -vpsubw %ymm14,%ymm\rl1,%ymm\rl1 -vpaddw %ymm14,%ymm\rh2,%ymm\rh2 - -vpsubw %ymm15,%ymm\rl2,%ymm\rl2 -vpaddw %ymm15,%ymm\rh3,%ymm\rh3 -.endm - -.macro level0 off -vpbroadcastq (_ZETAS_EXP+0)*2(%rsi),%ymm15 -vmovdqa (64*\off+128)*2(%rdi),%ymm8 -vmovdqa (64*\off+144)*2(%rdi),%ymm9 -vmovdqa (64*\off+160)*2(%rdi),%ymm10 -vmovdqa (64*\off+176)*2(%rdi),%ymm11 -vpbroadcastq 
(_ZETAS_EXP+4)*2(%rsi),%ymm2 - -mul 8,9,10,11 - -vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 - -reduce -update 3,4,5,6,7,8,9,10,11 - -vmovdqa %ymm3,(64*\off+ 0)*2(%rdi) -vmovdqa %ymm4,(64*\off+ 16)*2(%rdi) -vmovdqa %ymm5,(64*\off+ 32)*2(%rdi) -vmovdqa %ymm6,(64*\off+ 48)*2(%rdi) -vmovdqa %ymm8,(64*\off+128)*2(%rdi) -vmovdqa %ymm9,(64*\off+144)*2(%rdi) -vmovdqa %ymm10,(64*\off+160)*2(%rdi) -vmovdqa %ymm11,(64*\off+176)*2(%rdi) -.endm - -.macro levels1t6 off -/* level 1 */ -vmovdqa (_ZETAS_EXP+224*\off+16)*2(%rsi),%ymm15 -vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 -vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 -vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 -vmovdqa (128*\off+112)*2(%rdi),%ymm11 -vmovdqa (_ZETAS_EXP+224*\off+32)*2(%rsi),%ymm2 - -mul 8,9,10,11 - -vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 - -reduce -update 3,4,5,6,7,8,9,10,11 - -/* level 2 */ -shuffle8 5,10,7,10 -shuffle8 6,11,5,11 - -vmovdqa (_ZETAS_EXP+224*\off+48)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+64)*2(%rsi),%ymm2 - -mul 7,10,5,11 - -shuffle8 3,8,6,8 -shuffle8 4,9,3,9 - -reduce -update 4,6,8,3,9,7,10,5,11 - -/* level 3 */ -shuffle4 8,5,9,5 -shuffle4 3,11,8,11 - -vmovdqa (_ZETAS_EXP+224*\off+80)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+96)*2(%rsi),%ymm2 - -mul 9,5,8,11 - -shuffle4 4,7,3,7 -shuffle4 6,10,4,10 - -reduce -update 6,3,7,4,10,9,5,8,11 - -/* level 4 */ -shuffle2 7,8,10,8 -shuffle2 4,11,7,11 - -vmovdqa (_ZETAS_EXP+224*\off+112)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+128)*2(%rsi),%ymm2 - -mul 10,8,7,11 - -shuffle2 6,9,4,9 -shuffle2 3,5,6,5 - -reduce -update 3,4,9,6,5,10,8,7,11 - -/* level 5 */ -shuffle1 9,7,5,7 -shuffle1 6,11,9,11 - -vmovdqa (_ZETAS_EXP+224*\off+144)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+160)*2(%rsi),%ymm2 - -mul 5,7,9,11 - -shuffle1 3,10,6,10 -shuffle1 4,8,3,8 - 
-reduce -update 4,6,10,3,8,5,7,9,11 - -/* level 6 */ -vmovdqa (_ZETAS_EXP+224*\off+176)*2(%rsi),%ymm14 -vmovdqa (_ZETAS_EXP+224*\off+208)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+192)*2(%rsi),%ymm8 -vmovdqa (_ZETAS_EXP+224*\off+224)*2(%rsi),%ymm2 - -mul 10,3,9,11,14,15,8,2 - -reduce -update 8,4,6,5,7,10,3,9,11 - -vmovdqa %ymm8,(128*\off+ 0)*2(%rdi) -vmovdqa %ymm4,(128*\off+ 16)*2(%rdi) -vmovdqa %ymm10,(128*\off+ 32)*2(%rdi) -vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) -vmovdqa %ymm6,(128*\off+ 64)*2(%rdi) -vmovdqa %ymm5,(128*\off+ 80)*2(%rdi) -vmovdqa %ymm9,(128*\off+ 96)*2(%rdi) -vmovdqa %ymm11,(128*\off+112)*2(%rdi) -.endm - -.text -.global cdecl(ntt_avx) -cdecl(ntt_avx): -vmovdqa _16XQ*2(%rsi),%ymm0 - -level0 0 -level0 1 - -levels1t6 0 -levels1t6 1 - -ret diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.h deleted file mode 100644 index a4f48e343b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/ntt.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef NTT_H -#define NTT_H - -#include -#include - -#define ntt_avx KYBER_NAMESPACE(ntt_avx) -void ntt_avx(__m256i *r, const __m256i *qdata); -#define invntt_avx KYBER_NAMESPACE(invntt_avx) -void invntt_avx(__m256i *r, const __m256i *qdata); - -#define nttpack_avx KYBER_NAMESPACE(nttpack_avx) -void nttpack_avx(__m256i *r, const __m256i *qdata); -#define nttunpack_avx KYBER_NAMESPACE(nttunpack_avx) -void nttunpack_avx(__m256i *r, const __m256i *qdata); - -#define basemul_avx KYBER_NAMESPACE(basemul_avx) -void basemul_avx(__m256i *r, - const __m256i *a, - const __m256i *b, - const __m256i *qdata); - -#define ntttobytes_avx KYBER_NAMESPACE(ntttobytes_avx) -void ntttobytes_avx(uint8_t *r, const __m256i *a, const __m256i *qdata); -#define nttfrombytes_avx KYBER_NAMESPACE(nttfrombytes_avx) -void nttfrombytes_avx(__m256i *r, const uint8_t *a, const __m256i *qdata); - -#endif diff --git 
a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/params.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/params.h deleted file mode 100644 index ecfabce4a5..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/params.h +++ /dev/null @@ -1,68 +0,0 @@ -#ifndef PARAMS_H -#define PARAMS_H - -#ifndef KYBER_K -#define KYBER_K 3 /* Change this for different security strengths */ -#endif - -//#define KYBER_90S /* Uncomment this if you want the 90S variant */ - -/* Don't change parameters below this line */ -#if (KYBER_K == 2) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber512_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_512_avx2_##s -#endif -#elif (KYBER_K == 3) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber768_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_768_avx2_##s -#endif -#elif (KYBER_K == 4) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber1024_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_1024_avx2_##s -#endif -#else -#error "KYBER_K must be in {2,3,4}" -#endif - -#define KYBER_N 256 -#define KYBER_Q 3329 - -#define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */ -#define KYBER_SSBYTES 32 /* size in bytes of shared key */ - -#define KYBER_POLYBYTES 384 -#define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES) - -#if KYBER_K == 2 -#define KYBER_ETA1 3 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 3 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 4 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 160 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352) -#endif - -#define KYBER_ETA2 2 - -#define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES) -#define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES) -#define KYBER_INDCPA_SECRETKEYBYTES 
(KYBER_POLYVECBYTES) -#define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES) - -#define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES) -/* 32 bytes of additional space to save H(pk) */ -#define KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2*KYBER_SYMBYTES) -#define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES) - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.c deleted file mode 100644 index 681fd6d23e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.c +++ /dev/null @@ -1,519 +0,0 @@ -#include -#include -#include -#include "align.h" -#include "fips202x4.h" -#include "params.h" -#include "poly.h" -#include "ntt.h" -#include "consts.h" -#include "reduce.h" -#include "cbd.h" -#include "symmetric.h" - -/************************************************* -* Name: poly_compress -* -* Description: Compression and subsequent serialization of a polynomial. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). 
-* -* Arguments: - uint8_t *r: pointer to output byte array -* (of length KYBER_POLYCOMPRESSEDBYTES) -* - const poly *a: pointer to input polynomial -**************************************************/ -#if (KYBER_POLYCOMPRESSEDBYTES == 128) -void poly_compress(uint8_t r[128], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i shift1 = _mm256_set1_epi16(1 << 9); - const __m256i mask = _mm256_set1_epi16(15); - const __m256i shift2 = _mm256_set1_epi16((16 << 8) + 1); - const __m256i permdidx = _mm256_set_epi32(7,3,6,2,5,1,4,0); - - for(i=0;ivec[4*i+0]); - f1 = _mm256_load_si256(&a->vec[4*i+1]); - f2 = _mm256_load_si256(&a->vec[4*i+2]); - f3 = _mm256_load_si256(&a->vec[4*i+3]); - f0 = _mm256_mulhi_epi16(f0,v); - f1 = _mm256_mulhi_epi16(f1,v); - f2 = _mm256_mulhi_epi16(f2,v); - f3 = _mm256_mulhi_epi16(f3,v); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f1 = _mm256_mulhrs_epi16(f1,shift1); - f2 = _mm256_mulhrs_epi16(f2,shift1); - f3 = _mm256_mulhrs_epi16(f3,shift1); - f0 = _mm256_and_si256(f0,mask); - f1 = _mm256_and_si256(f1,mask); - f2 = _mm256_and_si256(f2,mask); - f3 = _mm256_and_si256(f3,mask); - f0 = _mm256_packus_epi16(f0,f1); - f2 = _mm256_packus_epi16(f2,f3); - f0 = _mm256_maddubs_epi16(f0,shift2); - f2 = _mm256_maddubs_epi16(f2,shift2); - f0 = _mm256_packus_epi16(f0,f2); - f0 = _mm256_permutevar8x32_epi32(f0,permdidx); - _mm256_storeu_si256((__m256i *)&r[32*i],f0); - } -} - -void poly_decompress(poly * restrict r, const uint8_t a[128]) -{ - unsigned int i; - __m128i t; - __m256i f; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(7,7,7,7,6,6,6,6,5,5,5,5,4,4,4,4, - 3,3,3,3,2,2,2,2,1,1,1,1,0,0,0,0); - const __m256i mask = _mm256_set1_epi32(0x00F0000F); - const __m256i shift = _mm256_set1_epi32((128 << 16) + 2048); - - for(i=0;ivec[i],f); - } -} - -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) -void poly_compress(uint8_t 
r[160], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i shift1 = _mm256_set1_epi16(1 << 10); - const __m256i mask = _mm256_set1_epi16(31); - const __m256i shift2 = _mm256_set1_epi16((32 << 8) + 1); - const __m256i shift3 = _mm256_set1_epi32((1024 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(12); - const __m256i shufbidx = _mm256_set_epi8( 8,-1,-1,-1,-1,-1, 4, 3, 2, 1, 0,-1,12,11,10, 9, - -1,12,11,10, 9, 8,-1,-1,-1,-1,-1 ,4, 3, 2, 1, 0); - - for(i=0;ivec[2*i+0]); - f1 = _mm256_load_si256(&a->vec[2*i+1]); - f0 = _mm256_mulhi_epi16(f0,v); - f1 = _mm256_mulhi_epi16(f1,v); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f1 = _mm256_mulhrs_epi16(f1,shift1); - f0 = _mm256_and_si256(f0,mask); - f1 = _mm256_and_si256(f1,mask); - f0 = _mm256_packus_epi16(f0,f1); - f0 = _mm256_maddubs_epi16(f0,shift2); // a0 a1 a2 a3 b0 b1 b2 b3 a4 a5 a6 a7 b4 b5 b6 b7 - f0 = _mm256_madd_epi16(f0,shift3); // a0 a1 b0 b1 a2 a3 b2 b3 - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f0 = _mm256_srlv_epi64(f0,sllvdidx); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blendv_epi8(t0,t1,_mm256_castsi256_si128(shufbidx)); - _mm_storeu_si128((__m128i *)&r[20*i+ 0],t0); - memcpy(&r[20*i+16],&t1,4); - } -} - -void poly_decompress(poly * restrict r, const uint8_t a[160]) -{ - unsigned int i; - __m128i t; - __m256i f; - int16_t ti; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(9,9,9,8,8,8,8,7,7,6,6,6,6,5,5,5, - 4,4,4,3,3,3,3,2,2,1,1,1,1,0,0,0); - const __m256i mask = _mm256_set_epi16(248,1984,62,496,3968,124,992,31, - 248,1984,62,496,3968,124,992,31); - const __m256i shift = _mm256_set_epi16(128,16,512,64,8,256,32,1024, - 128,16,512,64,8,256,32,1024); - - for(i=0;ivec[i],f); - } -} - -#endif - -/************************************************* -* Name: 
poly_tobytes -* -* Description: Serialization of a polynomial in NTT representation. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). The coefficients are orderd as output by -* poly_ntt(); the serialized output coefficients are in bitreversed -* order. -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYBYTES bytes) -* - poly *a: pointer to input polynomial -**************************************************/ -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) -{ - ntttobytes_avx(r, a->vec, qdata.vec); -} - -/************************************************* -* Name: poly_frombytes -* -* Description: De-serialization of a polynomial; -* inverse of poly_tobytes -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of KYBER_POLYBYTES bytes) -**************************************************/ -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) -{ - nttfrombytes_avx(r->vec, a, qdata.vec); -} - -/************************************************* -* Name: poly_frommsg -* -* Description: Convert 32-byte message to polynomial -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *msg: pointer to input message -**************************************************/ -void poly_frommsg(poly * restrict r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) -{ -#if (KYBER_INDCPA_MSGBYTES != 32) -#error "KYBER_INDCPA_MSGBYTES must be equal to 32!" 
-#endif - __m256i f, g0, g1, g2, g3, h0, h1, h2, h3; - const __m256i shift = _mm256_broadcastsi128_si256(_mm_set_epi32(0,1,2,3)); - const __m256i idx = _mm256_broadcastsi128_si256(_mm_set_epi8(15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0)); - const __m256i hqs = _mm256_set1_epi16((KYBER_Q+1)/2); - -#define FROMMSG64(i) \ - g3 = _mm256_shuffle_epi32(f,0x55*i); \ - g3 = _mm256_sllv_epi32(g3,shift); \ - g3 = _mm256_shuffle_epi8(g3,idx); \ - g0 = _mm256_slli_epi16(g3,12); \ - g1 = _mm256_slli_epi16(g3,8); \ - g2 = _mm256_slli_epi16(g3,4); \ - g0 = _mm256_srai_epi16(g0,15); \ - g1 = _mm256_srai_epi16(g1,15); \ - g2 = _mm256_srai_epi16(g2,15); \ - g3 = _mm256_srai_epi16(g3,15); \ - g0 = _mm256_and_si256(g0,hqs); /* 19 18 17 16 3 2 1 0 */ \ - g1 = _mm256_and_si256(g1,hqs); /* 23 22 21 20 7 6 5 4 */ \ - g2 = _mm256_and_si256(g2,hqs); /* 27 26 25 24 11 10 9 8 */ \ - g3 = _mm256_and_si256(g3,hqs); /* 31 30 29 28 15 14 13 12 */ \ - h0 = _mm256_unpacklo_epi64(g0,g1); \ - h2 = _mm256_unpackhi_epi64(g0,g1); \ - h1 = _mm256_unpacklo_epi64(g2,g3); \ - h3 = _mm256_unpackhi_epi64(g2,g3); \ - g0 = _mm256_permute2x128_si256(h0,h1,0x20); \ - g2 = _mm256_permute2x128_si256(h0,h1,0x31); \ - g1 = _mm256_permute2x128_si256(h2,h3,0x20); \ - g3 = _mm256_permute2x128_si256(h2,h3,0x31); \ - _mm256_store_si256(&r->vec[0+2*i+0],g0); \ - _mm256_store_si256(&r->vec[0+2*i+1],g1); \ - _mm256_store_si256(&r->vec[8+2*i+0],g2); \ - _mm256_store_si256(&r->vec[8+2*i+1],g3) - - f = _mm256_loadu_si256((__m256i *)msg); - FROMMSG64(0); - FROMMSG64(1); - FROMMSG64(2); - FROMMSG64(3); -} - -/************************************************* -* Name: poly_tomsg -* -* Description: Convert polynomial to 32-byte message. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). 
-* -* Arguments: - uint8_t *msg: pointer to output message -* - poly *a: pointer to input polynomial -**************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly * restrict a) -{ - unsigned int i; - uint32_t small; - __m256i f0, f1, g0, g1; - const __m256i hq = _mm256_set1_epi16((KYBER_Q - 1)/2); - const __m256i hhq = _mm256_set1_epi16((KYBER_Q - 1)/4); - - for(i=0;ivec[2*i+0]); - f1 = _mm256_load_si256(&a->vec[2*i+1]); - f0 = _mm256_sub_epi16(hq, f0); - f1 = _mm256_sub_epi16(hq, f1); - g0 = _mm256_srai_epi16(f0, 15); - g1 = _mm256_srai_epi16(f1, 15); - f0 = _mm256_xor_si256(f0, g0); - f1 = _mm256_xor_si256(f1, g1); - f0 = _mm256_sub_epi16(f0, hhq); - f1 = _mm256_sub_epi16(f1, hhq); - f0 = _mm256_packs_epi16(f0, f1); - f0 = _mm256_permute4x64_epi64(f0, 0xD8); - small = _mm256_movemask_epi8(f0); - memcpy(&msg[4*i], &small, 4); - } -} - -/************************************************* -* Name: poly_getnoise_eta1 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA1 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - ALIGNED_UINT8(KYBER_ETA1*KYBER_N/4+32) buf; // +32 bytes as required by poly_cbd_eta1 - prf(buf.coeffs, KYBER_ETA1*KYBER_N/4, seed, nonce); - poly_cbd_eta1(r, buf.vec); -} - -/************************************************* -* Name: poly_getnoise_eta2 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: 
pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - ALIGNED_UINT8(KYBER_ETA2*KYBER_N/4) buf; - prf(buf.coeffs, KYBER_ETA2*KYBER_N/4, seed, nonce); - poly_cbd_eta2(r, buf.vec); -} - -#ifndef KYBER_90S -#define NOISE_NBLOCKS ((KYBER_ETA1*KYBER_N/4+SHAKE256_RATE-1)/SHAKE256_RATE) -void poly_getnoise_eta1_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3) -{ - ALIGNED_UINT8(NOISE_NBLOCKS*SHAKE256_RATE) buf[4]; - __m256i f; - shake256x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - buf[0].coeffs[32] = nonce0; - buf[1].coeffs[32] = nonce1; - buf[2].coeffs[32] = nonce2; - buf[3].coeffs[32] = nonce3; - - shake256x4_inc_init(&state); - shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33); - shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state); - shake256x4_inc_ctx_release(&state); - - poly_cbd_eta1(r0, buf[0].vec); - poly_cbd_eta1(r1, buf[1].vec); - poly_cbd_eta1(r2, buf[2].vec); - poly_cbd_eta1(r3, buf[3].vec); -} - -#if KYBER_K == 2 -void poly_getnoise_eta1122_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3) -{ - ALIGNED_UINT8(NOISE_NBLOCKS*SHAKE256_RATE) buf[4]; - __m256i f; - shake256x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - buf[0].coeffs[32] = nonce0; - 
buf[1].coeffs[32] = nonce1; - buf[2].coeffs[32] = nonce2; - buf[3].coeffs[32] = nonce3; - - shake256x4_inc_init(&state); - shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33); - shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state); - shake256x4_inc_ctx_release(&state); - - poly_cbd_eta1(r0, buf[0].vec); - poly_cbd_eta1(r1, buf[1].vec); - poly_cbd_eta2(r2, buf[2].vec); - poly_cbd_eta2(r3, buf[3].vec); -} -#endif -#endif - -/************************************************* -* Name: poly_ntt -* -* Description: Computes negacyclic number-theoretic transform (NTT) of -* a polynomial in place. -* Input coefficients assumed to be in normal order, -* output coefficients are in special order that is natural -* for the vectorization. Input coefficients are assumed to be -* bounded by q in absolute value, output coefficients are bounded -* by 16118 in absolute value. -* -* Arguments: - poly *r: pointer to in/output polynomial -**************************************************/ -void poly_ntt(poly *r) -{ - ntt_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_invntt_tomont -* -* Description: Computes inverse of negacyclic number-theoretic transform (NTT) -* of a polynomial in place; -* Input coefficients assumed to be in special order from vectorized -* forward ntt, output in normal order. Input coefficients can be -* arbitrary 16-bit integers, output coefficients are bounded by 14870 -* in absolute value. -* -* Arguments: - poly *a: pointer to in/output polynomial -**************************************************/ -void poly_invntt_tomont(poly *r) -{ - invntt_avx(r->vec, qdata.vec); -} - -void poly_nttunpack(poly *r) -{ - nttunpack_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_basemul_montgomery -* -* Description: Multiplication of two polynomials in NTT domain. 
-* One of the input polynomials needs to have coefficients -* bounded by q, the other polynomial can have arbitrary -* coefficients. Output coefficients are bounded by 6656. -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) -{ - basemul_avx(r->vec, a->vec, b->vec, qdata.vec); -} - -/************************************************* -* Name: poly_tomont -* -* Description: Inplace conversion of all coefficients of a polynomial -* from normal domain to Montgomery domain -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_tomont(poly *r) -{ - tomont_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_reduce -* -* Description: Applies Barrett reduction to all coefficients of a polynomial -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_reduce(poly *r) -{ - reduce_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_add -* -* Description: Add two polynomials. No modular reduction -* is performed. 
-* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_add(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - __m256i f0, f1; - - for(i=0;ivec[i]); - f1 = _mm256_load_si256(&b->vec[i]); - f0 = _mm256_add_epi16(f0, f1); - _mm256_store_si256(&r->vec[i], f0); - } -} - -/************************************************* -* Name: poly_sub -* -* Description: Subtract two polynomials. No modular reduction -* is performed. -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_sub(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - __m256i f0, f1; - - for(i=0;ivec[i]); - f1 = _mm256_load_si256(&b->vec[i]); - f0 = _mm256_sub_epi16(f0, f1); - _mm256_store_si256(&r->vec[i], f0); - } -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.h deleted file mode 100644 index 6a9cf71c70..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/poly.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef POLY_H -#define POLY_H - -#include -#include "align.h" -#include "params.h" - -typedef ALIGNED_INT16(KYBER_N) poly; - -#define poly_compress KYBER_NAMESPACE(poly_compress) -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a); -#define poly_decompress KYBER_NAMESPACE(poly_decompress) -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]); - -#define poly_tobytes KYBER_NAMESPACE(poly_tobytes) -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a); -#define poly_frombytes KYBER_NAMESPACE(poly_frombytes) -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]); - -#define 
poly_frommsg KYBER_NAMESPACE(poly_frommsg) -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]); -#define poly_tomsg KYBER_NAMESPACE(poly_tomsg) -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r); - -#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1) -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2) -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#ifndef KYBER_90S -#define poly_getnoise_eta1_4x KYBER_NAMESPACE(poly_getnoise_eta2_4x) -void poly_getnoise_eta1_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3); - -#if KYBER_K == 2 -#define poly_getnoise_eta1122_4x KYBER_NAMESPACE(poly_getnoise_eta1122_4x) -void poly_getnoise_eta1122_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3); -#endif -#endif - - -#define poly_ntt KYBER_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); -#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont) -void poly_invntt_tomont(poly *r); -#define poly_nttunpack KYBER_NAMESPACE(poly_nttunpack) -void poly_nttunpack(poly *r); -#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery) -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b); -#define poly_tomont KYBER_NAMESPACE(poly_tomont) -void poly_tomont(poly *r); - -#define poly_reduce KYBER_NAMESPACE(poly_reduce) -void poly_reduce(poly *r); - -#define poly_add KYBER_NAMESPACE(poly_add) -void poly_add(poly *r, const poly *a, const poly *b); -#define poly_sub KYBER_NAMESPACE(poly_sub) -void poly_sub(poly *r, const poly *a, const poly *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.c 
b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.c deleted file mode 100644 index a0174b7b3f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.c +++ /dev/null @@ -1,307 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" -#include "consts.h" - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) -static void poly_compress10(uint8_t r[320], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i v8 = _mm256_slli_epi16(v,3); - const __m256i off = _mm256_set1_epi16(15); - const __m256i shift1 = _mm256_set1_epi16(1 << 12); - const __m256i mask = _mm256_set1_epi16(1023); - const __m256i shift2 = _mm256_set1_epi64x((1024LL << 48) + (1LL << 32) + (1024 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(12); - const __m256i shufbidx = _mm256_set_epi8( 8, 4, 3, 2, 1, 0,-1,-1,-1,-1,-1,-1,12,11,10, 9, - -1,-1,-1,-1,-1,-1,12,11,10, 9, 8, 4, 3, 2, 1, 0); - - for(i=0;ivec[i]); - f1 = _mm256_mullo_epi16(f0,v8); - f2 = _mm256_add_epi16(f0,off); - f0 = _mm256_slli_epi16(f0,3); - f0 = _mm256_mulhi_epi16(f0,v); - f2 = _mm256_sub_epi16(f1,f2); - f1 = _mm256_andnot_si256(f1,f2); - f1 = _mm256_srli_epi16(f1,15); - f0 = _mm256_sub_epi16(f0,f1); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f0 = _mm256_and_si256(f0,mask); - f0 = _mm256_madd_epi16(f0,shift2); - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f0 = _mm256_srli_epi64(f0,12); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blend_epi16(t0,t1,0xE0); - _mm_storeu_si128((__m128i *)&r[20*i+ 0],t0); - memcpy(&r[20*i+16],&t1,4); - } -} - -static void poly_decompress10(poly * restrict r, const uint8_t a[320+12]) -{ - unsigned int i; - __m256i f; - const __m256i q = _mm256_set1_epi32((KYBER_Q << 16) + 4*KYBER_Q); - const __m256i shufbidx = 
_mm256_set_epi8(11,10,10, 9, 9, 8, 8, 7, - 6, 5, 5, 4, 4, 3, 3, 2, - 9, 8, 8, 7, 7, 6, 6, 5, - 4, 3, 3, 2, 2, 1, 1, 0); - const __m256i sllvdidx = _mm256_set1_epi64x(4); - const __m256i mask = _mm256_set1_epi32((32736 << 16) + 8184); - - for(i=0;ivec[i],f); - } -} - -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) -static void poly_compress11(uint8_t r[352+2], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i v8 = _mm256_slli_epi16(v,3); - const __m256i off = _mm256_set1_epi16(36); - const __m256i shift1 = _mm256_set1_epi16(1 << 13); - const __m256i mask = _mm256_set1_epi16(2047); - const __m256i shift2 = _mm256_set1_epi64x((2048LL << 48) + (1LL << 32) + (2048 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(10); - const __m256i srlvqidx = _mm256_set_epi64x(30,10,30,10); - const __m256i shufbidx = _mm256_set_epi8( 4, 3, 2, 1, 0, 0,-1,-1,-1,-1,10, 9, 8, 7, 6, 5, - -1,-1,-1,-1,-1,10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); - - for(i=0;ivec[i]); - f1 = _mm256_mullo_epi16(f0,v8); - f2 = _mm256_add_epi16(f0,off); - f0 = _mm256_slli_epi16(f0,3); - f0 = _mm256_mulhi_epi16(f0,v); - f2 = _mm256_sub_epi16(f1,f2); - f1 = _mm256_andnot_si256(f1,f2); - f1 = _mm256_srli_epi16(f1,15); - f0 = _mm256_sub_epi16(f0,f1); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f0 = _mm256_and_si256(f0,mask); - f0 = _mm256_madd_epi16(f0,shift2); - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f1 = _mm256_bsrli_epi128(f0,8); - f0 = _mm256_srlv_epi64(f0,srlvqidx); - f1 = _mm256_slli_epi64(f1,34); - f0 = _mm256_add_epi64(f0,f1); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blendv_epi8(t0,t1,_mm256_castsi256_si128(shufbidx)); - _mm_storeu_si128((__m128i *)&r[22*i+ 0],t0); - _mm_storel_epi64((__m128i *)&r[22*i+16],t1); - } -} - -static void poly_decompress11(poly * restrict r, const uint8_t a[352+10]) -{ - 
unsigned int i; - __m256i f; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(13,12,12,11,10, 9, 9, 8, - 8, 7, 6, 5, 5, 4, 4, 3, - 10, 9, 9, 8, 7, 6, 6, 5, - 5, 4, 3, 2, 2, 1, 1, 0); - const __m256i srlvdidx = _mm256_set_epi32(0,0,1,0,0,0,1,0); - const __m256i srlvqidx = _mm256_set_epi64x(2,0,2,0); - const __m256i shift = _mm256_set_epi16(4,32,1,8,32,1,4,32,4,32,1,8,32,1,4,32); - const __m256i mask = _mm256_set1_epi16(32752); - - for(i=0;ivec[i],f); - } -} - -#endif - -/************************************************* -* Name: polyvec_compress -* -* Description: Compress and serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECCOMPRESSEDBYTES) -* - polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES+2], const polyvec *a) -{ - unsigned int i; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - for(i=0;ivec[i]); -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - for(i=0;ivec[i]); -#endif -} - -/************************************************* -* Name: polyvec_decompress -* -* Description: De-serialize and decompress vector of polynomials; -* approximate inverse of polyvec_compress -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYVECCOMPRESSEDBYTES) -**************************************************/ -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES+12]) -{ - unsigned int i; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - for(i=0;ivec[i],&a[320*i]); -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - for(i=0;ivec[i],&a[352*i]); -#endif -} - -/************************************************* -* Name: polyvec_tobytes -* -* Description: Serialize vector of polynomials -* 
-* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECBYTES) -* - polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_frombytes -* -* Description: De-serialize vector of polynomials; -* inverse of polyvec_tobytes -* -* Arguments: - uint8_t *r: pointer to output byte array -* - const polyvec *a: pointer to input vector of polynomials -* (of length KYBER_POLYVECBYTES) -**************************************************/ -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]) -{ - unsigned int i; - for(i=0;ivec[i], a+i*KYBER_POLYBYTES); -} - -/************************************************* -* Name: polyvec_ntt -* -* Description: Apply forward NTT to all elements of a vector of polynomials -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_ntt(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_invntt_tomont -* -* Description: Apply inverse NTT to all elements of a vector of polynomials -* and multiply by Montgomery factor 2^16 -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_invntt_tomont(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_basemul_acc_montgomery -* -* Description: Multiply elements in a and b in NTT domain, accumulate into r, -* and multiply by 2^-16. 
-* -* Arguments: - poly *r: pointer to output polynomial -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - poly tmp; - - poly_basemul_montgomery(r,&a->vec[0],&b->vec[0]); - for(i=1;ivec[i],&b->vec[i]); - poly_add(r,r,&tmp); - } -} - -/************************************************* -* Name: polyvec_reduce -* -* Description: Applies Barrett reduction to each coefficient -* of each element of a vector of polynomials; -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - polyvec *r: pointer to input/output polynomial -**************************************************/ -void polyvec_reduce(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_add -* -* Description: Add vectors of polynomials -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - for(i=0;ivec[i], &a->vec[i], &b->vec[i]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.h deleted file mode 100644 index 2ce23c31ff..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/polyvec.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef POLYVEC_H -#define POLYVEC_H - -#include -#include "params.h" -#include "poly.h" - -typedef struct{ - poly vec[KYBER_K]; -} polyvec; - -#define polyvec_compress KYBER_NAMESPACE(polyvec_compress) -void polyvec_compress(uint8_t 
r[KYBER_POLYVECCOMPRESSEDBYTES+2], const polyvec *a); -#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress) -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES+12]); - -#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes) -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a); -#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes) -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]); - -#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt) -void polyvec_ntt(polyvec *r); -#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont) -void polyvec_invntt_tomont(polyvec *r); - -#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery) -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b); - -#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce) -void polyvec_reduce(polyvec *r); - -#define polyvec_add KYBER_NAMESPACE(polyvec_add) -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/reduce.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/reduce.h deleted file mode 100644 index 5368185b5f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/reduce.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef REDUCE_H -#define REDUCE_H - -#include "params.h" -#include - -#define reduce_avx KYBER_NAMESPACE(reduce_avx) -void reduce_avx(__m256i *r, const __m256i *qdata); -#define tomont_avx KYBER_NAMESPACE(tomont_avx) -void tomont_avx(__m256i *r, const __m256i *qdata); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.c deleted file mode 100644 index 9060a44cb9..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.c +++ /dev/null @@ -1,398 +0,0 @@ -#include -#include -#include 
-#include "params.h" -#include "consts.h" -#include "rejsample.h" - -//#define BMI - -#ifndef BMI -static const uint8_t idx[256][8] = { - {-1, -1, -1, -1, -1, -1, -1, -1}, - { 0, -1, -1, -1, -1, -1, -1, -1}, - { 2, -1, -1, -1, -1, -1, -1, -1}, - { 0, 2, -1, -1, -1, -1, -1, -1}, - { 4, -1, -1, -1, -1, -1, -1, -1}, - { 0, 4, -1, -1, -1, -1, -1, -1}, - { 2, 4, -1, -1, -1, -1, -1, -1}, - { 0, 2, 4, -1, -1, -1, -1, -1}, - { 6, -1, -1, -1, -1, -1, -1, -1}, - { 0, 6, -1, -1, -1, -1, -1, -1}, - { 2, 6, -1, -1, -1, -1, -1, -1}, - { 0, 2, 6, -1, -1, -1, -1, -1}, - { 4, 6, -1, -1, -1, -1, -1, -1}, - { 0, 4, 6, -1, -1, -1, -1, -1}, - { 2, 4, 6, -1, -1, -1, -1, -1}, - { 0, 2, 4, 6, -1, -1, -1, -1}, - { 8, -1, -1, -1, -1, -1, -1, -1}, - { 0, 8, -1, -1, -1, -1, -1, -1}, - { 2, 8, -1, -1, -1, -1, -1, -1}, - { 0, 2, 8, -1, -1, -1, -1, -1}, - { 4, 8, -1, -1, -1, -1, -1, -1}, - { 0, 4, 8, -1, -1, -1, -1, -1}, - { 2, 4, 8, -1, -1, -1, -1, -1}, - { 0, 2, 4, 8, -1, -1, -1, -1}, - { 6, 8, -1, -1, -1, -1, -1, -1}, - { 0, 6, 8, -1, -1, -1, -1, -1}, - { 2, 6, 8, -1, -1, -1, -1, -1}, - { 0, 2, 6, 8, -1, -1, -1, -1}, - { 4, 6, 8, -1, -1, -1, -1, -1}, - { 0, 4, 6, 8, -1, -1, -1, -1}, - { 2, 4, 6, 8, -1, -1, -1, -1}, - { 0, 2, 4, 6, 8, -1, -1, -1}, - {10, -1, -1, -1, -1, -1, -1, -1}, - { 0, 10, -1, -1, -1, -1, -1, -1}, - { 2, 10, -1, -1, -1, -1, -1, -1}, - { 0, 2, 10, -1, -1, -1, -1, -1}, - { 4, 10, -1, -1, -1, -1, -1, -1}, - { 0, 4, 10, -1, -1, -1, -1, -1}, - { 2, 4, 10, -1, -1, -1, -1, -1}, - { 0, 2, 4, 10, -1, -1, -1, -1}, - { 6, 10, -1, -1, -1, -1, -1, -1}, - { 0, 6, 10, -1, -1, -1, -1, -1}, - { 2, 6, 10, -1, -1, -1, -1, -1}, - { 0, 2, 6, 10, -1, -1, -1, -1}, - { 4, 6, 10, -1, -1, -1, -1, -1}, - { 0, 4, 6, 10, -1, -1, -1, -1}, - { 2, 4, 6, 10, -1, -1, -1, -1}, - { 0, 2, 4, 6, 10, -1, -1, -1}, - { 8, 10, -1, -1, -1, -1, -1, -1}, - { 0, 8, 10, -1, -1, -1, -1, -1}, - { 2, 8, 10, -1, -1, -1, -1, -1}, - { 0, 2, 8, 10, -1, -1, -1, -1}, - { 4, 8, 10, -1, -1, -1, -1, -1}, - { 0, 4, 8, 10, -1, -1, 
-1, -1}, - { 2, 4, 8, 10, -1, -1, -1, -1}, - { 0, 2, 4, 8, 10, -1, -1, -1}, - { 6, 8, 10, -1, -1, -1, -1, -1}, - { 0, 6, 8, 10, -1, -1, -1, -1}, - { 2, 6, 8, 10, -1, -1, -1, -1}, - { 0, 2, 6, 8, 10, -1, -1, -1}, - { 4, 6, 8, 10, -1, -1, -1, -1}, - { 0, 4, 6, 8, 10, -1, -1, -1}, - { 2, 4, 6, 8, 10, -1, -1, -1}, - { 0, 2, 4, 6, 8, 10, -1, -1}, - {12, -1, -1, -1, -1, -1, -1, -1}, - { 0, 12, -1, -1, -1, -1, -1, -1}, - { 2, 12, -1, -1, -1, -1, -1, -1}, - { 0, 2, 12, -1, -1, -1, -1, -1}, - { 4, 12, -1, -1, -1, -1, -1, -1}, - { 0, 4, 12, -1, -1, -1, -1, -1}, - { 2, 4, 12, -1, -1, -1, -1, -1}, - { 0, 2, 4, 12, -1, -1, -1, -1}, - { 6, 12, -1, -1, -1, -1, -1, -1}, - { 0, 6, 12, -1, -1, -1, -1, -1}, - { 2, 6, 12, -1, -1, -1, -1, -1}, - { 0, 2, 6, 12, -1, -1, -1, -1}, - { 4, 6, 12, -1, -1, -1, -1, -1}, - { 0, 4, 6, 12, -1, -1, -1, -1}, - { 2, 4, 6, 12, -1, -1, -1, -1}, - { 0, 2, 4, 6, 12, -1, -1, -1}, - { 8, 12, -1, -1, -1, -1, -1, -1}, - { 0, 8, 12, -1, -1, -1, -1, -1}, - { 2, 8, 12, -1, -1, -1, -1, -1}, - { 0, 2, 8, 12, -1, -1, -1, -1}, - { 4, 8, 12, -1, -1, -1, -1, -1}, - { 0, 4, 8, 12, -1, -1, -1, -1}, - { 2, 4, 8, 12, -1, -1, -1, -1}, - { 0, 2, 4, 8, 12, -1, -1, -1}, - { 6, 8, 12, -1, -1, -1, -1, -1}, - { 0, 6, 8, 12, -1, -1, -1, -1}, - { 2, 6, 8, 12, -1, -1, -1, -1}, - { 0, 2, 6, 8, 12, -1, -1, -1}, - { 4, 6, 8, 12, -1, -1, -1, -1}, - { 0, 4, 6, 8, 12, -1, -1, -1}, - { 2, 4, 6, 8, 12, -1, -1, -1}, - { 0, 2, 4, 6, 8, 12, -1, -1}, - {10, 12, -1, -1, -1, -1, -1, -1}, - { 0, 10, 12, -1, -1, -1, -1, -1}, - { 2, 10, 12, -1, -1, -1, -1, -1}, - { 0, 2, 10, 12, -1, -1, -1, -1}, - { 4, 10, 12, -1, -1, -1, -1, -1}, - { 0, 4, 10, 12, -1, -1, -1, -1}, - { 2, 4, 10, 12, -1, -1, -1, -1}, - { 0, 2, 4, 10, 12, -1, -1, -1}, - { 6, 10, 12, -1, -1, -1, -1, -1}, - { 0, 6, 10, 12, -1, -1, -1, -1}, - { 2, 6, 10, 12, -1, -1, -1, -1}, - { 0, 2, 6, 10, 12, -1, -1, -1}, - { 4, 6, 10, 12, -1, -1, -1, -1}, - { 0, 4, 6, 10, 12, -1, -1, -1}, - { 2, 4, 6, 10, 12, -1, -1, -1}, - { 0, 2, 4, 6, 10, 12, 
-1, -1}, - { 8, 10, 12, -1, -1, -1, -1, -1}, - { 0, 8, 10, 12, -1, -1, -1, -1}, - { 2, 8, 10, 12, -1, -1, -1, -1}, - { 0, 2, 8, 10, 12, -1, -1, -1}, - { 4, 8, 10, 12, -1, -1, -1, -1}, - { 0, 4, 8, 10, 12, -1, -1, -1}, - { 2, 4, 8, 10, 12, -1, -1, -1}, - { 0, 2, 4, 8, 10, 12, -1, -1}, - { 6, 8, 10, 12, -1, -1, -1, -1}, - { 0, 6, 8, 10, 12, -1, -1, -1}, - { 2, 6, 8, 10, 12, -1, -1, -1}, - { 0, 2, 6, 8, 10, 12, -1, -1}, - { 4, 6, 8, 10, 12, -1, -1, -1}, - { 0, 4, 6, 8, 10, 12, -1, -1}, - { 2, 4, 6, 8, 10, 12, -1, -1}, - { 0, 2, 4, 6, 8, 10, 12, -1}, - {14, -1, -1, -1, -1, -1, -1, -1}, - { 0, 14, -1, -1, -1, -1, -1, -1}, - { 2, 14, -1, -1, -1, -1, -1, -1}, - { 0, 2, 14, -1, -1, -1, -1, -1}, - { 4, 14, -1, -1, -1, -1, -1, -1}, - { 0, 4, 14, -1, -1, -1, -1, -1}, - { 2, 4, 14, -1, -1, -1, -1, -1}, - { 0, 2, 4, 14, -1, -1, -1, -1}, - { 6, 14, -1, -1, -1, -1, -1, -1}, - { 0, 6, 14, -1, -1, -1, -1, -1}, - { 2, 6, 14, -1, -1, -1, -1, -1}, - { 0, 2, 6, 14, -1, -1, -1, -1}, - { 4, 6, 14, -1, -1, -1, -1, -1}, - { 0, 4, 6, 14, -1, -1, -1, -1}, - { 2, 4, 6, 14, -1, -1, -1, -1}, - { 0, 2, 4, 6, 14, -1, -1, -1}, - { 8, 14, -1, -1, -1, -1, -1, -1}, - { 0, 8, 14, -1, -1, -1, -1, -1}, - { 2, 8, 14, -1, -1, -1, -1, -1}, - { 0, 2, 8, 14, -1, -1, -1, -1}, - { 4, 8, 14, -1, -1, -1, -1, -1}, - { 0, 4, 8, 14, -1, -1, -1, -1}, - { 2, 4, 8, 14, -1, -1, -1, -1}, - { 0, 2, 4, 8, 14, -1, -1, -1}, - { 6, 8, 14, -1, -1, -1, -1, -1}, - { 0, 6, 8, 14, -1, -1, -1, -1}, - { 2, 6, 8, 14, -1, -1, -1, -1}, - { 0, 2, 6, 8, 14, -1, -1, -1}, - { 4, 6, 8, 14, -1, -1, -1, -1}, - { 0, 4, 6, 8, 14, -1, -1, -1}, - { 2, 4, 6, 8, 14, -1, -1, -1}, - { 0, 2, 4, 6, 8, 14, -1, -1}, - {10, 14, -1, -1, -1, -1, -1, -1}, - { 0, 10, 14, -1, -1, -1, -1, -1}, - { 2, 10, 14, -1, -1, -1, -1, -1}, - { 0, 2, 10, 14, -1, -1, -1, -1}, - { 4, 10, 14, -1, -1, -1, -1, -1}, - { 0, 4, 10, 14, -1, -1, -1, -1}, - { 2, 4, 10, 14, -1, -1, -1, -1}, - { 0, 2, 4, 10, 14, -1, -1, -1}, - { 6, 10, 14, -1, -1, -1, -1, -1}, - { 0, 6, 10, 14, -1, 
-1, -1, -1}, - { 2, 6, 10, 14, -1, -1, -1, -1}, - { 0, 2, 6, 10, 14, -1, -1, -1}, - { 4, 6, 10, 14, -1, -1, -1, -1}, - { 0, 4, 6, 10, 14, -1, -1, -1}, - { 2, 4, 6, 10, 14, -1, -1, -1}, - { 0, 2, 4, 6, 10, 14, -1, -1}, - { 8, 10, 14, -1, -1, -1, -1, -1}, - { 0, 8, 10, 14, -1, -1, -1, -1}, - { 2, 8, 10, 14, -1, -1, -1, -1}, - { 0, 2, 8, 10, 14, -1, -1, -1}, - { 4, 8, 10, 14, -1, -1, -1, -1}, - { 0, 4, 8, 10, 14, -1, -1, -1}, - { 2, 4, 8, 10, 14, -1, -1, -1}, - { 0, 2, 4, 8, 10, 14, -1, -1}, - { 6, 8, 10, 14, -1, -1, -1, -1}, - { 0, 6, 8, 10, 14, -1, -1, -1}, - { 2, 6, 8, 10, 14, -1, -1, -1}, - { 0, 2, 6, 8, 10, 14, -1, -1}, - { 4, 6, 8, 10, 14, -1, -1, -1}, - { 0, 4, 6, 8, 10, 14, -1, -1}, - { 2, 4, 6, 8, 10, 14, -1, -1}, - { 0, 2, 4, 6, 8, 10, 14, -1}, - {12, 14, -1, -1, -1, -1, -1, -1}, - { 0, 12, 14, -1, -1, -1, -1, -1}, - { 2, 12, 14, -1, -1, -1, -1, -1}, - { 0, 2, 12, 14, -1, -1, -1, -1}, - { 4, 12, 14, -1, -1, -1, -1, -1}, - { 0, 4, 12, 14, -1, -1, -1, -1}, - { 2, 4, 12, 14, -1, -1, -1, -1}, - { 0, 2, 4, 12, 14, -1, -1, -1}, - { 6, 12, 14, -1, -1, -1, -1, -1}, - { 0, 6, 12, 14, -1, -1, -1, -1}, - { 2, 6, 12, 14, -1, -1, -1, -1}, - { 0, 2, 6, 12, 14, -1, -1, -1}, - { 4, 6, 12, 14, -1, -1, -1, -1}, - { 0, 4, 6, 12, 14, -1, -1, -1}, - { 2, 4, 6, 12, 14, -1, -1, -1}, - { 0, 2, 4, 6, 12, 14, -1, -1}, - { 8, 12, 14, -1, -1, -1, -1, -1}, - { 0, 8, 12, 14, -1, -1, -1, -1}, - { 2, 8, 12, 14, -1, -1, -1, -1}, - { 0, 2, 8, 12, 14, -1, -1, -1}, - { 4, 8, 12, 14, -1, -1, -1, -1}, - { 0, 4, 8, 12, 14, -1, -1, -1}, - { 2, 4, 8, 12, 14, -1, -1, -1}, - { 0, 2, 4, 8, 12, 14, -1, -1}, - { 6, 8, 12, 14, -1, -1, -1, -1}, - { 0, 6, 8, 12, 14, -1, -1, -1}, - { 2, 6, 8, 12, 14, -1, -1, -1}, - { 0, 2, 6, 8, 12, 14, -1, -1}, - { 4, 6, 8, 12, 14, -1, -1, -1}, - { 0, 4, 6, 8, 12, 14, -1, -1}, - { 2, 4, 6, 8, 12, 14, -1, -1}, - { 0, 2, 4, 6, 8, 12, 14, -1}, - {10, 12, 14, -1, -1, -1, -1, -1}, - { 0, 10, 12, 14, -1, -1, -1, -1}, - { 2, 10, 12, 14, -1, -1, -1, -1}, - { 0, 2, 10, 12, 14, -1, 
-1, -1}, - { 4, 10, 12, 14, -1, -1, -1, -1}, - { 0, 4, 10, 12, 14, -1, -1, -1}, - { 2, 4, 10, 12, 14, -1, -1, -1}, - { 0, 2, 4, 10, 12, 14, -1, -1}, - { 6, 10, 12, 14, -1, -1, -1, -1}, - { 0, 6, 10, 12, 14, -1, -1, -1}, - { 2, 6, 10, 12, 14, -1, -1, -1}, - { 0, 2, 6, 10, 12, 14, -1, -1}, - { 4, 6, 10, 12, 14, -1, -1, -1}, - { 0, 4, 6, 10, 12, 14, -1, -1}, - { 2, 4, 6, 10, 12, 14, -1, -1}, - { 0, 2, 4, 6, 10, 12, 14, -1}, - { 8, 10, 12, 14, -1, -1, -1, -1}, - { 0, 8, 10, 12, 14, -1, -1, -1}, - { 2, 8, 10, 12, 14, -1, -1, -1}, - { 0, 2, 8, 10, 12, 14, -1, -1}, - { 4, 8, 10, 12, 14, -1, -1, -1}, - { 0, 4, 8, 10, 12, 14, -1, -1}, - { 2, 4, 8, 10, 12, 14, -1, -1}, - { 0, 2, 4, 8, 10, 12, 14, -1}, - { 6, 8, 10, 12, 14, -1, -1, -1}, - { 0, 6, 8, 10, 12, 14, -1, -1}, - { 2, 6, 8, 10, 12, 14, -1, -1}, - { 0, 2, 6, 8, 10, 12, 14, -1}, - { 4, 6, 8, 10, 12, 14, -1, -1}, - { 0, 4, 6, 8, 10, 12, 14, -1}, - { 2, 4, 6, 8, 10, 12, 14, -1}, - { 0, 2, 4, 6, 8, 10, 12, 14} -}; -#endif - -#define _mm256_cmpge_epu16(a, b) _mm256_cmpeq_epi16(_mm256_max_epu16(a, b), a) -#define _mm_cmpge_epu16(a, b) _mm_cmpeq_epi16(_mm_max_epu16(a, b), a) - -unsigned int rej_uniform_avx(int16_t * restrict r, const uint8_t *buf) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - uint32_t good; -#ifdef BMI - uint64_t idx0, idx1, idx2, idx3; -#endif - const __m256i bound = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i ones = _mm256_set1_epi8(1); - const __m256i mask = _mm256_set1_epi16(0xFFF); - const __m256i idx8 = _mm256_set_epi8(15,14,14,13,12,11,11,10, - 9, 8, 8, 7, 6, 5, 5, 4, - 11,10,10, 9, 8, 7, 7, 6, - 5, 4, 4, 3, 2, 1, 1, 0); - __m256i f0, f1, g0, g1, g2, g3; - __m128i f, t, pilo, pihi; - - ctr = pos = 0; - while(ctr <= KYBER_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 56) { - f0 = _mm256_loadu_si256((__m256i *)&buf[pos]); - f1 = _mm256_loadu_si256((__m256i *)&buf[pos+24]); - f0 = _mm256_permute4x64_epi64(f0, 0x94); - f1 = _mm256_permute4x64_epi64(f1, 0x94); - f0 = _mm256_shuffle_epi8(f0, 
idx8); - f1 = _mm256_shuffle_epi8(f1, idx8); - g0 = _mm256_srli_epi16(f0, 4); - g1 = _mm256_srli_epi16(f1, 4); - f0 = _mm256_blend_epi16(f0, g0, 0xAA); - f1 = _mm256_blend_epi16(f1, g1, 0xAA); - f0 = _mm256_and_si256(f0, mask); - f1 = _mm256_and_si256(f1, mask); - pos += 48; - - g0 = _mm256_cmpgt_epi16(bound, f0); - g1 = _mm256_cmpgt_epi16(bound, f1); - - g0 = _mm256_packs_epi16(g0, g1); - good = _mm256_movemask_epi8(g0); - -#ifdef BMI - idx0 = _pdep_u64(good >> 0, 0x0101010101010101); - idx1 = _pdep_u64(good >> 8, 0x0101010101010101); - idx2 = _pdep_u64(good >> 16, 0x0101010101010101); - idx3 = _pdep_u64(good >> 24, 0x0101010101010101); - idx0 = (idx0 << 8) - idx0; - idx0 = _pext_u64(0x0E0C0A0806040200, idx0); - idx1 = (idx1 << 8) - idx1; - idx1 = _pext_u64(0x0E0C0A0806040200, idx1); - idx2 = (idx2 << 8) - idx2; - idx2 = _pext_u64(0x0E0C0A0806040200, idx2); - idx3 = (idx3 << 8) - idx3; - idx3 = _pext_u64(0x0E0C0A0806040200, idx3); - - g0 = _mm256_castsi128_si256(_mm_cvtsi64_si128(idx0)); - g1 = _mm256_castsi128_si256(_mm_cvtsi64_si128(idx1)); - g0 = _mm256_inserti128_si256(g0, _mm_cvtsi64_si128(idx2), 1); - g1 = _mm256_inserti128_si256(g1, _mm_cvtsi64_si128(idx3), 1); -#else - g0 = _mm256_castsi128_si256(_mm_loadl_epi64((__m128i *)&idx[(good >> 0) & 0xFF])); - g1 = _mm256_castsi128_si256(_mm_loadl_epi64((__m128i *)&idx[(good >> 8) & 0xFF])); - g0 = _mm256_inserti128_si256(g0, _mm_loadl_epi64((__m128i *)&idx[(good >> 16) & 0xFF]), 1); - g1 = _mm256_inserti128_si256(g1, _mm_loadl_epi64((__m128i *)&idx[(good >> 24) & 0xFF]), 1); -#endif - - g2 = _mm256_add_epi8(g0, ones); - g3 = _mm256_add_epi8(g1, ones); - g0 = _mm256_unpacklo_epi8(g0, g2); - g1 = _mm256_unpacklo_epi8(g1, g3); - - f0 = _mm256_shuffle_epi8(f0, g0); - f1 = _mm256_shuffle_epi8(f1, g1); - - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f0)); - ctr += _mm_popcnt_u32((good >> 0) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f0, 1)); - ctr += _mm_popcnt_u32((good >> 
16) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f1)); - ctr += _mm_popcnt_u32((good >> 8) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f1, 1)); - ctr += _mm_popcnt_u32((good >> 24) & 0xFF); - } - - while(ctr <= KYBER_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 16) { - f = _mm_loadu_si128((__m128i *)&buf[pos]); - f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8)); - t = _mm_srli_epi16(f, 4); - f = _mm_blend_epi16(f, t, 0xAA); - f = _mm_and_si128(f, _mm256_castsi256_si128(mask)); - pos += 12; - - t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f); - good = _mm_movemask_epi8(t); - -#ifdef BMI - good &= 0x5555; - idx0 = _pdep_u64(good, 0x1111111111111111); - idx0 = (idx0 << 8) - idx0; - idx0 = _pext_u64(0x0E0C0A0806040200, idx0); - pilo = _mm_cvtsi64_si128(idx0); -#else - good = _pext_u32(good, 0x5555); - pilo = _mm_loadl_epi64((__m128i *)&idx[good]); -#endif - - pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones)); - pilo = _mm_unpacklo_epi8(pilo, pihi); - f = _mm_shuffle_epi8(f, pilo); - _mm_storeu_si128((__m128i *)&r[ctr], f); - ctr += _mm_popcnt_u32(good); - } - - while(ctr < KYBER_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3) { - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)); - pos += 3; - - if(val0 < KYBER_Q) - r[ctr++] = val0; - if(val1 < KYBER_Q && ctr < KYBER_N) - r[ctr++] = val1; - } - - return ctr; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.h deleted file mode 100644 index 3be5e2192e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/rejsample.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef REJSAMPLE_H -#define REJSAMPLE_H - -#include -#include "params.h" -#include "symmetric.h" - -#define REJ_UNIFORM_AVX_NBLOCKS ((12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) -#define 
REJ_UNIFORM_AVX_BUFLEN (REJ_UNIFORM_AVX_NBLOCKS*XOF_BLOCKBYTES) - -#define rej_uniform_avx KYBER_NAMESPACE(rej_uniform_avx) -unsigned int rej_uniform_avx(int16_t *r, const uint8_t *buf); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric-shake.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric-shake.c deleted file mode 100644 index 20f451882e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric-shake.c +++ /dev/null @@ -1,74 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "symmetric.h" -#include "fips202.h" - -/************************************************* -* Name: kyber_shake128_absorb -* -* Description: Absorb step of the SHAKE128 specialized for the Kyber context. -* -* Arguments: - keccak_state *state: pointer to (uninitialized) output Keccak state -* - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state -* - uint8_t i: additional byte of input -* - uint8_t j: additional byte of input -**************************************************/ -void kyber_shake128_absorb(shake128incctx *state, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y) -{ - uint8_t extseed[KYBER_SYMBYTES+2]; - - memcpy(extseed, seed, KYBER_SYMBYTES); - extseed[KYBER_SYMBYTES+0] = x; - extseed[KYBER_SYMBYTES+1] = y; - - shake128_absorb_once(state, extseed, sizeof(extseed)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t 
key[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t extkey[KYBER_SYMBYTES+1]; - - memcpy(extkey, key, KYBER_SYMBYTES); - extkey[KYBER_SYMBYTES] = nonce; - - shake256(out, outlen, extkey, sizeof(extkey)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]) -{ - shake256incctx s; - - shake256_inc_init(&s); - shake256_inc_absorb(&s, key, KYBER_SYMBYTES); - shake256_inc_absorb(&s, input, KYBER_CIPHERTEXTBYTES); - shake256_inc_finalize(&s); - shake256_inc_squeeze(out, KYBER_SSBYTES, &s); - shake256_inc_ctx_release(&s); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric.h deleted file mode 100644 index e4941f7a86..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/symmetric.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef SYMMETRIC_H -#define SYMMETRIC_H - -#include -#include -#include "params.h" - -#include "fips202.h" -#include "fips202x4.h" - -typedef shake128incctx xof_state; - -#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb) -void kyber_shake128_absorb(shake128incctx *s, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y); - -#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf) -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce); - -#define kyber_shake256_rkprf 
KYBER_NAMESPACE(kyber_shake256_rkprf) -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]); - -#define XOF_BLOCKBYTES SHAKE128_RATE - -#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) -#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) -#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y) -#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) -#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE) -#define rkprf(OUT, KEY, INPUT) kyber_shake256_rkprf(OUT, KEY, INPUT) - -#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/verify.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/verify.c deleted file mode 100644 index 06243b837f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_avx2/verify.c +++ /dev/null @@ -1,83 +0,0 @@ -#include -#include -#include -#include "verify.h" - -/************************************************* -* Name: verify -* -* Description: Compare two arrays for equality in constant time. -* -* Arguments: const uint8_t *a: pointer to first byte array -* const uint8_t *b: pointer to second byte array -* size_t len: length of the byte arrays -* -* Returns 0 if the byte arrays are equal, 1 otherwise -**************************************************/ -int verify(const uint8_t *a, const uint8_t *b, size_t len) -{ - size_t i; - uint64_t r; - __m256i f, g, h; - - h = _mm256_setzero_si256(); - for(i=0;i> 63; - return r; -} - -/************************************************* -* Name: cmov -* -* Description: Copy len bytes from x to r if b is 1; -* don't modify x if b is 0. Requires b to be in {0,1}; -* assumes two's complement representation of negative integers. -* Runs in constant time. 
-* -* Arguments: uint8_t *r: pointer to output byte array -* const uint8_t *x: pointer to input byte array -* size_t len: Amount of bytes to be copied -* uint8_t b: Condition bit; has to be in {0,1} -**************************************************/ -void cmov(uint8_t * restrict r, const uint8_t *x, size_t len, uint8_t b) -{ - size_t i; - __m256i xvec, rvec, bvec; - -#if defined(__GNUC__) || defined(__clang__) - // Prevent the compiler from - // 1) inferring that b is 0/1-valued, and - // 2) handling the two cases with a branch. - // This is not necessary when verify.c and kem.c are separate translation - // units, but we expect that downstream consumers will copy this code and/or - // change how it is built. - __asm__("" : "+r"(b) : /* no inputs */); -#endif - - bvec = _mm256_set1_epi64x(-(uint64_t)b); - for(i=0;i -#include -#include "params.h" - -#define verify KYBER_NAMESPACE(verify) -int verify(const uint8_t *a, const uint8_t *b, size_t len); - -#define cmov KYBER_NAMESPACE(cmov) -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); - -#define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/api.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/api.h deleted file mode 100644 index 70d40f3f3e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/api.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef API_H -#define API_H - -#include - -#define pqcrystals_kyber512_SECRETKEYBYTES 1632 -#define pqcrystals_kyber512_PUBLICKEYBYTES 800 -#define pqcrystals_kyber512_CIPHERTEXTBYTES 768 -#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber512_ENCCOINBYTES 32 -#define pqcrystals_kyber512_BYTES 32 - -#define pqcrystals_kyber512_ref_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES -#define pqcrystals_kyber512_ref_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES -#define 
pqcrystals_kyber512_ref_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES -#define pqcrystals_kyber512_ref_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES -#define pqcrystals_kyber512_ref_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES -#define pqcrystals_kyber512_ref_BYTES pqcrystals_kyber512_BYTES - -int pqcrystals_kyber512_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber512_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber512_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber512_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber512_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber768_SECRETKEYBYTES 2400 -#define pqcrystals_kyber768_PUBLICKEYBYTES 1184 -#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088 -#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber768_ENCCOINBYTES 32 -#define pqcrystals_kyber768_BYTES 32 - -#define pqcrystals_kyber768_ref_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES -#define pqcrystals_kyber768_ref_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES -#define pqcrystals_kyber768_ref_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES -#define pqcrystals_kyber768_ref_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES -#define pqcrystals_kyber768_ref_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES -#define pqcrystals_kyber768_ref_BYTES pqcrystals_kyber768_BYTES - -int pqcrystals_kyber768_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber768_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber768_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber768_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber768_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber1024_SECRETKEYBYTES 3168 -#define 
pqcrystals_kyber1024_PUBLICKEYBYTES 1568 -#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568 -#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber1024_ENCCOINBYTES 32 -#define pqcrystals_kyber1024_BYTES 32 - -#define pqcrystals_kyber1024_ref_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES -#define pqcrystals_kyber1024_ref_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES -#define pqcrystals_kyber1024_ref_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES -#define pqcrystals_kyber1024_ref_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES -#define pqcrystals_kyber1024_ref_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES -#define pqcrystals_kyber1024_ref_BYTES pqcrystals_kyber1024_BYTES - -int pqcrystals_kyber1024_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber1024_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber1024_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber1024_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber1024_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.c deleted file mode 100644 index 1500ffea56..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.c +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include "params.h" -#include "cbd.h" - -/************************************************* -* Name: load32_littleendian -* -* Description: load 4 bytes into a 32-bit integer -* in little-endian order -* -* Arguments: - const uint8_t *x: pointer to input byte array -* -* Returns 32-bit unsigned integer loaded from x -**************************************************/ -static uint32_t load32_littleendian(const uint8_t x[4]) -{ - uint32_t r; - r = (uint32_t)x[0]; - r |= (uint32_t)x[1] << 8; - r |= (uint32_t)x[2] << 16; 
- r |= (uint32_t)x[3] << 24; - return r; -} - -/************************************************* -* Name: load24_littleendian -* -* Description: load 3 bytes into a 32-bit integer -* in little-endian order. -* This function is only needed for Kyber-512 -* -* Arguments: - const uint8_t *x: pointer to input byte array -* -* Returns 32-bit unsigned integer loaded from x (most significant byte is zero) -**************************************************/ -#if KYBER_ETA1 == 3 -static uint32_t load24_littleendian(const uint8_t x[3]) -{ - uint32_t r; - r = (uint32_t)x[0]; - r |= (uint32_t)x[1] << 8; - r |= (uint32_t)x[2] << 16; - return r; -} -#endif - - -/************************************************* -* Name: cbd2 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *buf: pointer to input byte array -**************************************************/ -static void cbd2(poly *r, const uint8_t buf[2*KYBER_N/4]) -{ - unsigned int i,j; - uint32_t t,d; - int16_t a,b; - - for(i=0;i>1) & 0x55555555; - - for(j=0;j<8;j++) { - a = (d >> (4*j+0)) & 0x3; - b = (d >> (4*j+2)) & 0x3; - r->coeffs[8*i+j] = a - b; - } - } -} - -/************************************************* -* Name: cbd3 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=3. 
-* This function is only needed for Kyber-512 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *buf: pointer to input byte array -**************************************************/ -#if KYBER_ETA1 == 3 -static void cbd3(poly *r, const uint8_t buf[3*KYBER_N/4]) -{ - unsigned int i,j; - uint32_t t,d; - int16_t a,b; - - for(i=0;i>1) & 0x00249249; - d += (t>>2) & 0x00249249; - - for(j=0;j<4;j++) { - a = (d >> (6*j+0)) & 0x7; - b = (d >> (6*j+3)) & 0x7; - r->coeffs[4*i+j] = a - b; - } - } -} -#endif - -void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1*KYBER_N/4]) -{ -#if KYBER_ETA1 == 2 - cbd2(r, buf); -#elif KYBER_ETA1 == 3 - cbd3(r, buf); -#else -#error "This implementation requires eta1 in {2,3}" -#endif -} - -void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2*KYBER_N/4]) -{ -#if KYBER_ETA2 == 2 - cbd2(r, buf); -#else -#error "This implementation requires eta2 = 2" -#endif -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.h deleted file mode 100644 index 7b677d745d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/cbd.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CBD_H -#define CBD_H - -#include -#include "params.h" -#include "poly.h" - -#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1) -void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1*KYBER_N/4]); - -#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2) -void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2*KYBER_N/4]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/indcpa.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/indcpa.c deleted file mode 100644 index 726cfa985d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/indcpa.c +++ /dev/null @@ -1,334 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "indcpa.h" -#include "polyvec.h" -#include "poly.h" -#include 
"ntt.h" -#include "symmetric.h" -#include "randombytes.h" - -/************************************************* -* Name: pack_pk -* -* Description: Serialize the public key as concatenation of the -* serialized vector of polynomials pk -* and the public seed used to generate the matrix A. -* -* Arguments: uint8_t *r: pointer to the output serialized public key -* polyvec *pk: pointer to the input public-key polyvec -* const uint8_t *seed: pointer to the input public seed -**************************************************/ -static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], - polyvec *pk, - const uint8_t seed[KYBER_SYMBYTES]) -{ - polyvec_tobytes(r, pk); - memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES); -} - -/************************************************* -* Name: unpack_pk -* -* Description: De-serialize public key from a byte array; -* approximate inverse of pack_pk -* -* Arguments: - polyvec *pk: pointer to output public-key polynomial vector -* - uint8_t *seed: pointer to output seed to generate matrix A -* - const uint8_t *packedpk: pointer to input serialized public key -**************************************************/ -static void unpack_pk(polyvec *pk, - uint8_t seed[KYBER_SYMBYTES], - const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) -{ - polyvec_frombytes(pk, packedpk); - memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES); -} - -/************************************************* -* Name: pack_sk -* -* Description: Serialize the secret key -* -* Arguments: - uint8_t *r: pointer to output serialized secret key -* - polyvec *sk: pointer to input vector of polynomials (secret key) -**************************************************/ -static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) -{ - polyvec_tobytes(r, sk); -} - -/************************************************* -* Name: unpack_sk -* -* Description: De-serialize the secret key; inverse of pack_sk -* -* Arguments: - polyvec *sk: pointer to output vector 
of polynomials (secret key) -* - const uint8_t *packedsk: pointer to input serialized secret key -**************************************************/ -static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) -{ - polyvec_frombytes(sk, packedsk); -} - -/************************************************* -* Name: pack_ciphertext -* -* Description: Serialize the ciphertext as concatenation of the -* compressed and serialized vector of polynomials b -* and the compressed and serialized polynomial v -* -* Arguments: uint8_t *r: pointer to the output serialized ciphertext -* poly *pk: pointer to the input vector of polynomials b -* poly *v: pointer to the input polynomial v -**************************************************/ -static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v) -{ - polyvec_compress(r, b); - poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v); -} - -/************************************************* -* Name: unpack_ciphertext -* -* Description: De-serialize and decompress ciphertext from a byte array; -* approximate inverse of pack_ciphertext -* -* Arguments: - polyvec *b: pointer to the output vector of polynomials b -* - poly *v: pointer to the output polynomial v -* - const uint8_t *c: pointer to the input serialized ciphertext -**************************************************/ -static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES]) -{ - polyvec_decompress(b, c); - poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES); -} - -/************************************************* -* Name: rej_uniform -* -* Description: Run rejection sampling on uniform random bytes to generate -* uniform random integers mod q -* -* Arguments: - int16_t *r: pointer to output buffer -* - unsigned int len: requested number of 16-bit integers (uniform mod q) -* - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes) -* - unsigned int buflen: length of input 
buffer in bytes -* -* Returns number of sampled 16-bit integers (at most len) -**************************************************/ -static unsigned int rej_uniform(int16_t *r, - unsigned int len, - const uint8_t *buf, - unsigned int buflen) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - - ctr = pos = 0; - while(ctr < len && pos + 3 <= buflen) { - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF; - pos += 3; - - if(val0 < KYBER_Q) - r[ctr++] = val0; - if(ctr < len && val1 < KYBER_Q) - r[ctr++] = val1; - } - - return ctr; -} - -#define gen_a(A,B) gen_matrix(A,B,0) -#define gen_at(A,B) gen_matrix(A,B,1) - -/************************************************* -* Name: gen_matrix -* -* Description: Deterministically generate matrix A (or the transpose of A) -* from a seed. Entries of the matrix are polynomials that look -* uniformly random. Performs rejection sampling on output of -* a XOF -* -* Arguments: - polyvec *a: pointer to ouptput matrix A -* - const uint8_t *seed: pointer to input seed -* - int transposed: boolean deciding whether A or A^T is generated -**************************************************/ -#if(XOF_BLOCKBYTES % 3) -#error "Implementation of gen_matrix assumes that XOF_BLOCKBYTES is a multiple of 3" -#endif - -#define GEN_MATRIX_NBLOCKS ((12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) -// Not static for benchmarking -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed) -{ - unsigned int ctr, i, j; - unsigned int buflen; - uint8_t buf[GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]; - xof_state state; - xof_init(&state, seed); - - for(i=0;i -#include "params.h" -#include "polyvec.h" - -#define gen_matrix KYBER_NAMESPACE(gen_matrix) -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed); - -#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand) -void indcpa_keypair_derand(uint8_t 
pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_enc KYBER_NAMESPACE(indcpa_enc) -void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_dec KYBER_NAMESPACE(indcpa_dec) -void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.c deleted file mode 100644 index 63abc1029c..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.c +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "kem.h" -#include "indcpa.h" -#include "verify.h" -#include "symmetric.h" -#include "randombytes.h" -/************************************************* -* Name: crypto_kem_keypair_derand -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* - uint8_t *coins: pointer to input randomness -* (an already allocated array filled with 2*KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair_derand(uint8_t *pk, - uint8_t *sk, - const uint8_t *coins) -{ - indcpa_keypair_derand(pk, sk, coins); - memcpy(sk+KYBER_INDCPA_SECRETKEYBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_h(sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - /* Value z for pseudo-random output on reject */ - 
memcpy(sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES, coins+KYBER_SYMBYTES, KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_keypair -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair(uint8_t *pk, - uint8_t *sk) -{ - uint8_t coins[2*KYBER_SYMBYTES]; - randombytes(coins, 2*KYBER_SYMBYTES); - crypto_kem_keypair_derand(pk, sk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc_derand -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - const uint8_t *coins: pointer to input randomness -* (an already allocated array filled with KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc_derand(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk, - const uint8_t *coins) -{ - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - - memcpy(buf, coins, KYBER_SYMBYTES); - - /* Multitarget countermeasure for coins + contributory KEM */ - hash_h(buf+KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(ct, buf, pk, kr+KYBER_SYMBYTES); - - 
memcpy(ss,kr,KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk) -{ - uint8_t coins[KYBER_SYMBYTES]; - randombytes(coins, KYBER_SYMBYTES); - crypto_kem_enc_derand(ct, ss, pk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_dec -* -* Description: Generates shared secret for given -* cipher text and private key -* -* Arguments: - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *ct: pointer to input cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - const uint8_t *sk: pointer to input private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0. -* -* On failure, ss will contain a pseudo-random value. 
-**************************************************/ -int crypto_kem_dec(uint8_t *ss, - const uint8_t *ct, - const uint8_t *sk) -{ - int fail; - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - uint8_t cmp[KYBER_CIPHERTEXTBYTES+KYBER_SYMBYTES]; - const uint8_t *pk = sk+KYBER_INDCPA_SECRETKEYBYTES; - - indcpa_dec(buf, ct, sk); - - /* Multitarget countermeasure for coins + contributory KEM */ - memcpy(buf+KYBER_SYMBYTES, sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, KYBER_SYMBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(cmp, buf, pk, kr+KYBER_SYMBYTES); - - fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES); - - /* Compute rejection key */ - rkprf(ss,sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES,ct); - - /* Copy true key to return buffer if fail is false */ - cmov(ss,kr,KYBER_SYMBYTES,!fail); - - return 0; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.h deleted file mode 100644 index 234f11966b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/kem.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef KEM_H -#define KEM_H - -#include -#include "params.h" - -#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES -#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES -#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES -#define CRYPTO_BYTES KYBER_SSBYTES - -#if (KYBER_K == 2) -#define CRYPTO_ALGNAME "Kyber512" -#elif (KYBER_K == 3) -#define CRYPTO_ALGNAME "Kyber768" -#elif (KYBER_K == 4) -#define CRYPTO_ALGNAME "Kyber1024" -#endif - -#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand) -int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); - -#define crypto_kem_keypair KYBER_NAMESPACE(keypair) -int crypto_kem_keypair(uint8_t *pk, uint8_t *sk); - -#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand) -int crypto_kem_enc_derand(uint8_t *ct, 
uint8_t *ss, const uint8_t *pk, const uint8_t *coins); - -#define crypto_kem_enc KYBER_NAMESPACE(enc) -int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); - -#define crypto_kem_dec KYBER_NAMESPACE(dec) -int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.c deleted file mode 100644 index 2f2eb10b2f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.c +++ /dev/null @@ -1,146 +0,0 @@ -#include -#include "params.h" -#include "ntt.h" -#include "reduce.h" - -/* Code to generate zetas and zetas_inv used in the number-theoretic transform: - -#define KYBER_ROOT_OF_UNITY 17 - -static const uint8_t tree[128] = { - 0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, - 4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124, - 2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122, - 6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126, - 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, - 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, - 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, - 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127 -}; - -void init_ntt() { - unsigned int i; - int16_t tmp[128]; - - tmp[0] = MONT; - for(i=1;i<128;i++) - tmp[i] = fqmul(tmp[i-1],MONT*KYBER_ROOT_OF_UNITY % KYBER_Q); - - for(i=0;i<128;i++) { - zetas[i] = tmp[tree[i]]; - if(zetas[i] > KYBER_Q/2) - zetas[i] -= KYBER_Q; - if(zetas[i] < -KYBER_Q/2) - zetas[i] += KYBER_Q; - } -} -*/ - -const int16_t zetas[128] = { - -1044, -758, -359, -1517, 1493, 1422, 287, 202, - -171, 622, 1577, 182, 962, -1202, -1474, 1468, - 573, -1325, 264, 383, -829, 1458, -1602, -130, - -681, 1017, 732, 608, -1542, 411, -205, -1571, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, 
-1618, -1162, 126, 1469, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -1103, 430, 555, 843, -1251, 871, 1550, 105, - 422, 587, 177, -235, -291, -460, 1574, 1653, - -246, 778, 1159, -147, -777, 1483, -602, 1119, - -1590, 644, -872, 349, 418, 329, -156, -75, - 817, 1097, 603, 610, 1322, -1285, -1465, 384, - -1215, -136, 1218, -1335, -874, 220, -1187, -1659, - -1185, -1530, -1278, 794, -1510, -854, -870, 478, - -108, -308, 996, 991, 958, -1460, 1522, 1628 -}; - -/************************************************* -* Name: fqmul -* -* Description: Multiplication followed by Montgomery reduction -* -* Arguments: - int16_t a: first factor -* - int16_t b: second factor -* -* Returns 16-bit integer congruent to a*b*R^{-1} mod q -**************************************************/ -static int16_t fqmul(int16_t a, int16_t b) { - return montgomery_reduce((int32_t)a*b); -} - -/************************************************* -* Name: ntt -* -* Description: Inplace number-theoretic transform (NTT) in Rq. -* input is in standard order, output is in bitreversed order -* -* Arguments: - int16_t r[256]: pointer to input/output vector of elements of Zq -**************************************************/ -void ntt(int16_t r[256]) { - unsigned int len, start, j, k; - int16_t t, zeta; - - k = 1; - for(len = 128; len >= 2; len >>= 1) { - for(start = 0; start < 256; start = j + len) { - zeta = zetas[k++]; - for(j = start; j < start + len; j++) { - t = fqmul(zeta, r[j + len]); - r[j + len] = r[j] - t; - r[j] = r[j] + t; - } - } - } -} - -/************************************************* -* Name: invntt_tomont -* -* Description: Inplace inverse number-theoretic transform in Rq and -* multiplication by Montgomery factor 2^16. 
-* Input is in bitreversed order, output is in standard order -* -* Arguments: - int16_t r[256]: pointer to input/output vector of elements of Zq -**************************************************/ -void invntt(int16_t r[256]) { - unsigned int start, len, j, k; - int16_t t, zeta; - const int16_t f = 1441; // mont^2/128 - - k = 127; - for(len = 2; len <= 128; len <<= 1) { - for(start = 0; start < 256; start = j + len) { - zeta = zetas[k--]; - for(j = start; j < start + len; j++) { - t = r[j]; - r[j] = barrett_reduce(t + r[j + len]); - r[j + len] = r[j + len] - t; - r[j + len] = fqmul(zeta, r[j + len]); - } - } - } - - for(j = 0; j < 256; j++) - r[j] = fqmul(r[j], f); -} - -/************************************************* -* Name: basemul -* -* Description: Multiplication of polynomials in Zq[X]/(X^2-zeta) -* used for multiplication of elements in Rq in NTT domain -* -* Arguments: - int16_t r[2]: pointer to the output polynomial -* - const int16_t a[2]: pointer to the first factor -* - const int16_t b[2]: pointer to the second factor -* - int16_t zeta: integer defining the reduction polynomial -**************************************************/ -void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta) -{ - r[0] = fqmul(a[1], b[1]); - r[0] = fqmul(r[0], zeta); - r[0] += fqmul(a[0], b[0]); - r[1] = fqmul(a[0], b[1]); - r[1] += fqmul(a[1], b[0]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.h deleted file mode 100644 index 227ea74f08..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/ntt.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef NTT_H -#define NTT_H - -#include -#include "params.h" - -#define zetas KYBER_NAMESPACE(zetas) -extern const int16_t zetas[128]; - -#define ntt KYBER_NAMESPACE(ntt) -void ntt(int16_t poly[256]); - -#define invntt KYBER_NAMESPACE(invntt) -void invntt(int16_t poly[256]); - -#define basemul 
KYBER_NAMESPACE(basemul) -void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/params.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/params.h deleted file mode 100644 index fb4190b311..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/params.h +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef PARAMS_H -#define PARAMS_H - -#ifndef KYBER_K -#define KYBER_K 3 /* Change this for different security strengths */ -#endif - - -/* Don't change parameters below this line */ -#if (KYBER_K == 2) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_512_ref_##s -#elif (KYBER_K == 3) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_768_ref_##s -#elif (KYBER_K == 4) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_1024_ref_##s -#else -#error "KYBER_K must be in {2,3,4}" -#endif - -#define KYBER_N 256 -#define KYBER_Q 3329 - -#define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */ -#define KYBER_SSBYTES 32 /* size in bytes of shared key */ - -#define KYBER_POLYBYTES 384 -#define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES) - -#if KYBER_K == 2 -#define KYBER_ETA1 3 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 3 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 4 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 160 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352) -#endif - -#define KYBER_ETA2 2 - -#define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES) -#define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES) -#define KYBER_INDCPA_SECRETKEYBYTES (KYBER_POLYVECBYTES) -#define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES) - -#define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES) -/* 32 bytes of additional space to save H(pk) */ -#define 
KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2*KYBER_SYMBYTES) -#define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES) - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.c deleted file mode 100644 index cbd3abfb54..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.c +++ /dev/null @@ -1,360 +0,0 @@ -#include -#include "params.h" -#include "poly.h" -#include "ntt.h" -#include "reduce.h" -#include "cbd.h" -#include "symmetric.h" -#include "verify.h" - -/************************************************* -* Name: poly_compress -* -* Description: Compression and subsequent serialization of a polynomial -* -* Arguments: - uint8_t *r: pointer to output byte array -* (of length KYBER_POLYCOMPRESSEDBYTES) -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a) -{ - unsigned int i,j; - int16_t u; - uint32_t d0; - uint8_t t[8]; - -#if (KYBER_POLYCOMPRESSEDBYTES == 128) - - for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint16_t)u << 4) + KYBER_Q/2)/KYBER_Q) & 15; */ - d0 = u << 4; - d0 += 1665; - d0 *= 80635; - d0 >>= 28; - t[j] = d0 & 0xf; - } - - r[0] = t[0] | (t[1] << 4); - r[1] = t[2] | (t[3] << 4); - r[2] = t[4] | (t[5] << 4); - r[3] = t[6] | (t[7] << 4); - r += 4; - } -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) - for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint32_t)u << 5) + KYBER_Q/2)/KYBER_Q) & 31; */ - d0 = u << 5; - d0 += 1664; - d0 *= 40318; - d0 >>= 27; - t[j] = d0 & 0x1f; - } - - r[0] = (t[0] >> 0) | (t[1] << 5); - r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7); - r[2] = (t[3] >> 1) | (t[4] << 4); - r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6); - r[4] = (t[6] >> 2) | (t[7] << 3); - r += 5; - } -#else -#error "KYBER_POLYCOMPRESSEDBYTES needs to be 
in {128, 160}" -#endif -} - -/************************************************* -* Name: poly_decompress -* -* Description: De-serialization and subsequent decompression of a polynomial; -* approximate inverse of poly_compress -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYCOMPRESSEDBYTES bytes) -**************************************************/ -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]) -{ - unsigned int i; - -#if (KYBER_POLYCOMPRESSEDBYTES == 128) - for(i=0;icoeffs[2*i+0] = (((uint16_t)(a[0] & 15)*KYBER_Q) + 8) >> 4; - r->coeffs[2*i+1] = (((uint16_t)(a[0] >> 4)*KYBER_Q) + 8) >> 4; - a += 1; - } -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) - unsigned int j; - uint8_t t[8]; - for(i=0;i> 0); - t[1] = (a[0] >> 5) | (a[1] << 3); - t[2] = (a[1] >> 2); - t[3] = (a[1] >> 7) | (a[2] << 1); - t[4] = (a[2] >> 4) | (a[3] << 4); - t[5] = (a[3] >> 1); - t[6] = (a[3] >> 6) | (a[4] << 2); - t[7] = (a[4] >> 3); - a += 5; - - for(j=0;j<8;j++) - r->coeffs[8*i+j] = ((uint32_t)(t[j] & 31)*KYBER_Q + 16) >> 5; - } -#else -#error "KYBER_POLYCOMPRESSEDBYTES needs to be in {128, 160}" -#endif -} - -/************************************************* -* Name: poly_tobytes -* -* Description: Serialization of a polynomial -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYBYTES bytes) -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) -{ - unsigned int i; - uint16_t t0, t1; - - for(i=0;icoeffs[2*i]; - t0 += ((int16_t)t0 >> 15) & KYBER_Q; - t1 = a->coeffs[2*i+1]; - t1 += ((int16_t)t1 >> 15) & KYBER_Q; - r[3*i+0] = (t0 >> 0); - r[3*i+1] = (t0 >> 8) | (t1 << 4); - r[3*i+2] = (t1 >> 4); - } -} - -/************************************************* -* Name: poly_frombytes -* -* Description: De-serialization of a polynomial; -* inverse 
of poly_tobytes -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of KYBER_POLYBYTES bytes) -**************************************************/ -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) -{ - unsigned int i; - for(i=0;icoeffs[2*i] = ((a[3*i+0] >> 0) | ((uint16_t)a[3*i+1] << 8)) & 0xFFF; - r->coeffs[2*i+1] = ((a[3*i+1] >> 4) | ((uint16_t)a[3*i+2] << 4)) & 0xFFF; - } -} - -/************************************************* -* Name: poly_frommsg -* -* Description: Convert 32-byte message to polynomial -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *msg: pointer to input message -**************************************************/ -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) -{ - unsigned int i,j; - -#if (KYBER_INDCPA_MSGBYTES != KYBER_N/8) -#error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!" -#endif - - for(i=0;icoeffs[8*i+j] = 0; - cmov_int16(r->coeffs+8*i+j, ((KYBER_Q+1)/2), (msg[i] >> j)&1); - } - } -} - -/************************************************* -* Name: poly_tomsg -* -* Description: Convert polynomial to 32-byte message -* -* Arguments: - uint8_t *msg: pointer to output message -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a) -{ - unsigned int i,j; - uint32_t t; - - for(i=0;icoeffs[8*i+j]; - // t += ((int16_t)t >> 15) & KYBER_Q; - // t = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1; - t <<= 1; - t += 1665; - t *= 80635; - t >>= 28; - t &= 1; - msg[i] |= t << j; - } - } -} - -/************************************************* -* Name: poly_getnoise_eta1 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA1 -* -* Arguments: - poly *r: pointer to output polynomial -* - const 
uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t buf[KYBER_ETA1*KYBER_N/4]; - prf(buf, sizeof(buf), seed, nonce); - poly_cbd_eta1(r, buf); -} - -/************************************************* -* Name: poly_getnoise_eta2 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t buf[KYBER_ETA2*KYBER_N/4]; - prf(buf, sizeof(buf), seed, nonce); - poly_cbd_eta2(r, buf); -} - - -/************************************************* -* Name: poly_ntt -* -* Description: Computes negacyclic number-theoretic transform (NTT) of -* a polynomial in place; -* inputs assumed to be in normal order, output in bitreversed order -* -* Arguments: - uint16_t *r: pointer to in/output polynomial -**************************************************/ -void poly_ntt(poly *r) -{ - ntt(r->coeffs); - poly_reduce(r); -} - -/************************************************* -* Name: poly_invntt_tomont -* -* Description: Computes inverse of negacyclic number-theoretic transform (NTT) -* of a polynomial in place; -* inputs assumed to be in bitreversed order, output in normal order -* -* Arguments: - uint16_t *a: pointer to in/output polynomial -**************************************************/ -void poly_invntt_tomont(poly *r) -{ - invntt(r->coeffs); -} - -/************************************************* -* Name: poly_basemul_montgomery 
-* -* Description: Multiplication of two polynomials in NTT domain -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[4*i], &a->coeffs[4*i], &b->coeffs[4*i], zetas[64+i]); - basemul(&r->coeffs[4*i+2], &a->coeffs[4*i+2], &b->coeffs[4*i+2], -zetas[64+i]); - } -} - -/************************************************* -* Name: poly_tomont -* -* Description: Inplace conversion of all coefficients of a polynomial -* from normal domain to Montgomery domain -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_tomont(poly *r) -{ - unsigned int i; - const int16_t f = (1ULL << 32) % KYBER_Q; - for(i=0;icoeffs[i] = montgomery_reduce((int32_t)r->coeffs[i]*f); -} - -/************************************************* -* Name: poly_reduce -* -* Description: Applies Barrett reduction to all coefficients of a polynomial -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_reduce(poly *r) -{ - unsigned int i; - for(i=0;icoeffs[i] = barrett_reduce(r->coeffs[i]); -} - -/************************************************* -* Name: poly_add -* -* Description: Add two polynomials; no modular reduction is performed -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_add(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[i] = a->coeffs[i] + b->coeffs[i]; -} - -/************************************************* -* 
Name: poly_sub -* -* Description: Subtract two polynomials; no modular reduction is performed -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_sub(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[i] = a->coeffs[i] - b->coeffs[i]; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.h deleted file mode 100644 index 9a99c7cdad..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/poly.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef POLY_H -#define POLY_H - -#include -#include "params.h" - -/* - * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial - * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] - */ -typedef struct{ - int16_t coeffs[KYBER_N]; -} poly; - -#define poly_compress KYBER_NAMESPACE(poly_compress) -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a); -#define poly_decompress KYBER_NAMESPACE(poly_decompress) -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]); - -#define poly_tobytes KYBER_NAMESPACE(poly_tobytes) -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a); -#define poly_frombytes KYBER_NAMESPACE(poly_frombytes) -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]); - -#define poly_frommsg KYBER_NAMESPACE(poly_frommsg) -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]); -#define poly_tomsg KYBER_NAMESPACE(poly_tomsg) -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r); - -#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1) -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2) -void poly_getnoise_eta2(poly *r, 
const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_ntt KYBER_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); -#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont) -void poly_invntt_tomont(poly *r); -#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery) -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b); -#define poly_tomont KYBER_NAMESPACE(poly_tomont) -void poly_tomont(poly *r); - -#define poly_reduce KYBER_NAMESPACE(poly_reduce) -void poly_reduce(poly *r); - -#define poly_add KYBER_NAMESPACE(poly_add) -void poly_add(poly *r, const poly *a, const poly *b); -#define poly_sub KYBER_NAMESPACE(poly_sub) -void poly_sub(poly *r, const poly *a, const poly *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.c deleted file mode 100644 index 669f6a5f1d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.c +++ /dev/null @@ -1,246 +0,0 @@ -#include -#include "params.h" -#include "poly.h" -#include "polyvec.h" - -/************************************************* -* Name: polyvec_compress -* -* Description: Compress and serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECCOMPRESSEDBYTES) -* - const polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a) -{ - unsigned int i,j,k; - uint64_t d0; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - uint16_t t[8]; - for(i=0;ivec[i].coeffs[8*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 11) + KYBER_Q/2)/KYBER_Q) & 0x7ff; */ - d0 = t[k]; - d0 <<= 11; - d0 += 1664; - d0 *= 645084; - d0 >>= 31; - t[k] = d0 & 0x7ff; - } - - r[ 0] = (t[0] >> 0); - r[ 1] = (t[0] >> 8) | (t[1] << 3); - r[ 2] = (t[1] >> 
5) | (t[2] << 6); - r[ 3] = (t[2] >> 2); - r[ 4] = (t[2] >> 10) | (t[3] << 1); - r[ 5] = (t[3] >> 7) | (t[4] << 4); - r[ 6] = (t[4] >> 4) | (t[5] << 7); - r[ 7] = (t[5] >> 1); - r[ 8] = (t[5] >> 9) | (t[6] << 2); - r[ 9] = (t[6] >> 6) | (t[7] << 5); - r[10] = (t[7] >> 3); - r += 11; - } - } -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - uint16_t t[4]; - for(i=0;ivec[i].coeffs[4*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 10) + KYBER_Q/2)/ KYBER_Q) & 0x3ff; */ - d0 = t[k]; - d0 <<= 10; - d0 += 1665; - d0 *= 1290167; - d0 >>= 32; - t[k] = d0 & 0x3ff; - } - - r[0] = (t[0] >> 0); - r[1] = (t[0] >> 8) | (t[1] << 2); - r[2] = (t[1] >> 6) | (t[2] << 4); - r[3] = (t[2] >> 4) | (t[3] << 6); - r[4] = (t[3] >> 2); - r += 5; - } - } -#else -#error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}" -#endif -} - -/************************************************* -* Name: polyvec_decompress -* -* Description: De-serialize and decompress vector of polynomials; -* approximate inverse of polyvec_compress -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYVECCOMPRESSEDBYTES) -**************************************************/ -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]) -{ - unsigned int i,j,k; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - uint16_t t[8]; - for(i=0;i> 0) | ((uint16_t)a[ 1] << 8); - t[1] = (a[1] >> 3) | ((uint16_t)a[ 2] << 5); - t[2] = (a[2] >> 6) | ((uint16_t)a[ 3] << 2) | ((uint16_t)a[4] << 10); - t[3] = (a[4] >> 1) | ((uint16_t)a[ 5] << 7); - t[4] = (a[5] >> 4) | ((uint16_t)a[ 6] << 4); - t[5] = (a[6] >> 7) | ((uint16_t)a[ 7] << 1) | ((uint16_t)a[8] << 9); - t[6] = (a[8] >> 2) | ((uint16_t)a[ 9] << 6); - t[7] = (a[9] >> 5) | ((uint16_t)a[10] << 3); - a += 11; - - for(k=0;k<8;k++) - r->vec[i].coeffs[8*j+k] = ((uint32_t)(t[k] & 0x7FF)*KYBER_Q + 1024) 
>> 11; - } - } -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - uint16_t t[4]; - for(i=0;i> 0) | ((uint16_t)a[1] << 8); - t[1] = (a[1] >> 2) | ((uint16_t)a[2] << 6); - t[2] = (a[2] >> 4) | ((uint16_t)a[3] << 4); - t[3] = (a[3] >> 6) | ((uint16_t)a[4] << 2); - a += 5; - - for(k=0;k<4;k++) - r->vec[i].coeffs[4*j+k] = ((uint32_t)(t[k] & 0x3FF)*KYBER_Q + 512) >> 10; - } - } -#else -#error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}" -#endif -} - -/************************************************* -* Name: polyvec_tobytes -* -* Description: Serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECBYTES) -* - const polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_frombytes -* -* Description: De-serialize vector of polynomials; -* inverse of polyvec_tobytes -* -* Arguments: - uint8_t *r: pointer to output byte array -* - const polyvec *a: pointer to input vector of polynomials -* (of length KYBER_POLYVECBYTES) -**************************************************/ -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]) -{ - unsigned int i; - for(i=0;ivec[i], a+i*KYBER_POLYBYTES); -} - -/************************************************* -* Name: polyvec_ntt -* -* Description: Apply forward NTT to all elements of a vector of polynomials -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_ntt(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_invntt_tomont -* -* Description: Apply inverse NTT to all elements of a vector of polynomials -* and multiply by 
Montgomery factor 2^16 -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_invntt_tomont(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_basemul_acc_montgomery -* -* Description: Multiply elements of a and b in NTT domain, accumulate into r, -* and multiply by 2^-16. -* -* Arguments: - poly *r: pointer to output polynomial -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - poly t; - - poly_basemul_montgomery(r, &a->vec[0], &b->vec[0]); - for(i=1;ivec[i], &b->vec[i]); - poly_add(r, r, &t); - } - - poly_reduce(r); -} - -/************************************************* -* Name: polyvec_reduce -* -* Description: Applies Barrett reduction to each coefficient -* of each element of a vector of polynomials; -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - polyvec *r: pointer to input/output polynomial -**************************************************/ -void polyvec_reduce(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_add -* -* Description: Add vectors of polynomials -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - for(i=0;ivec[i], &a->vec[i], &b->vec[i]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.h 
b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.h deleted file mode 100644 index 57b605494e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/polyvec.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef POLYVEC_H -#define POLYVEC_H - -#include -#include "params.h" -#include "poly.h" - -typedef struct{ - poly vec[KYBER_K]; -} polyvec; - -#define polyvec_compress KYBER_NAMESPACE(polyvec_compress) -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a); -#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress) -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]); - -#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes) -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a); -#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes) -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]); - -#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt) -void polyvec_ntt(polyvec *r); -#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont) -void polyvec_invntt_tomont(polyvec *r); - -#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery) -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b); - -#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce) -void polyvec_reduce(polyvec *r); - -#define polyvec_add KYBER_NAMESPACE(polyvec_add) -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.c deleted file mode 100644 index 9d8e7edf83..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.c +++ /dev/null @@ -1,42 +0,0 @@ -#include -#include "params.h" -#include "reduce.h" - -/************************************************* -* Name: montgomery_reduce -* -* Description: Montgomery reduction; 
given a 32-bit integer a, computes -* 16-bit integer congruent to a * R^-1 mod q, where R=2^16 -* -* Arguments: - int32_t a: input integer to be reduced; -* has to be in {-q2^15,...,q2^15-1} -* -* Returns: integer in {-q+1,...,q-1} congruent to a * R^-1 modulo q. -**************************************************/ -int16_t montgomery_reduce(int32_t a) -{ - int16_t t; - - t = (int16_t)a*QINV; - t = (a - (int32_t)t*KYBER_Q) >> 16; - return t; -} - -/************************************************* -* Name: barrett_reduce -* -* Description: Barrett reduction; given a 16-bit integer a, computes -* centered representative congruent to a mod q in {-(q-1)/2,...,(q-1)/2} -* -* Arguments: - int16_t a: input integer to be reduced -* -* Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. -**************************************************/ -int16_t barrett_reduce(int16_t a) { - int16_t t; - const int16_t v = ((1<<26) + KYBER_Q/2)/KYBER_Q; - - t = ((int32_t)v*a + (1<<25)) >> 26; - t *= KYBER_Q; - return a - t; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.h deleted file mode 100644 index c1bc1e4c7b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/reduce.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef REDUCE_H -#define REDUCE_H - -#include -#include "params.h" - -#define MONT -1044 // 2^16 mod q -#define QINV -3327 // q^-1 mod 2^16 - -#define montgomery_reduce KYBER_NAMESPACE(montgomery_reduce) -int16_t montgomery_reduce(int32_t a); - -#define barrett_reduce KYBER_NAMESPACE(barrett_reduce) -int16_t barrett_reduce(int16_t a); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric-shake.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric-shake.c deleted file mode 100644 index 20f451882e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric-shake.c +++ 
/dev/null @@ -1,74 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "symmetric.h" -#include "fips202.h" - -/************************************************* -* Name: kyber_shake128_absorb -* -* Description: Absorb step of the SHAKE128 specialized for the Kyber context. -* -* Arguments: - keccak_state *state: pointer to (uninitialized) output Keccak state -* - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state -* - uint8_t i: additional byte of input -* - uint8_t j: additional byte of input -**************************************************/ -void kyber_shake128_absorb(shake128incctx *state, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y) -{ - uint8_t extseed[KYBER_SYMBYTES+2]; - - memcpy(extseed, seed, KYBER_SYMBYTES); - extseed[KYBER_SYMBYTES+0] = x; - extseed[KYBER_SYMBYTES+1] = y; - - shake128_absorb_once(state, extseed, sizeof(extseed)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t extkey[KYBER_SYMBYTES+1]; - - memcpy(extkey, key, KYBER_SYMBYTES); - extkey[KYBER_SYMBYTES] = nonce; - - shake256(out, outlen, extkey, sizeof(extkey)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number 
of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]) -{ - shake256incctx s; - - shake256_inc_init(&s); - shake256_inc_absorb(&s, key, KYBER_SYMBYTES); - shake256_inc_absorb(&s, input, KYBER_CIPHERTEXTBYTES); - shake256_inc_finalize(&s); - shake256_inc_squeeze(out, KYBER_SSBYTES, &s); - shake256_inc_ctx_release(&s); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric.h deleted file mode 100644 index 2acc66f98d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/symmetric.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef SYMMETRIC_H -#define SYMMETRIC_H - -#include -#include -#include "params.h" - -#include "fips202.h" - -typedef shake128incctx xof_state; - -#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb) -void kyber_shake128_absorb(shake128incctx *s, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y); - -#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf) -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce); - -#define kyber_shake256_rkprf KYBER_NAMESPACE(kyber_shake256_rkprf) -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]); - -#define XOF_BLOCKBYTES SHAKE128_RATE - -#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) -#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) -#define xof_init(STATE, SEED) shake128_inc_init(STATE) -#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y) -#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) 
shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) -#define xof_release(STATE) shake128_inc_ctx_release(STATE) -#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE) -#define rkprf(OUT, KEY, INPUT) kyber_shake256_rkprf(OUT, KEY, INPUT) - -#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/verify.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/verify.c deleted file mode 100644 index 914ccd448f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-1024_ref/verify.c +++ /dev/null @@ -1,75 +0,0 @@ -#include -#include -#include "verify.h" - -/************************************************* -* Name: verify -* -* Description: Compare two arrays for equality in constant time. -* -* Arguments: const uint8_t *a: pointer to first byte array -* const uint8_t *b: pointer to second byte array -* size_t len: length of the byte arrays -* -* Returns 0 if the byte arrays are equal, 1 otherwise -**************************************************/ -int verify(const uint8_t *a, const uint8_t *b, size_t len) -{ - size_t i; - uint8_t r = 0; - - for(i=0;i> 63; -} - -/************************************************* -* Name: cmov -* -* Description: Copy len bytes from x to r if b is 1; -* don't modify x if b is 0. Requires b to be in {0,1}; -* assumes two's complement representation of negative integers. -* Runs in constant time. -* -* Arguments: uint8_t *r: pointer to output byte array -* const uint8_t *x: pointer to input byte array -* size_t len: Amount of bytes to be copied -* uint8_t b: Condition bit; has to be in {0,1} -**************************************************/ -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b) -{ - size_t i; - -#if defined(__GNUC__) || defined(__clang__) - // Prevent the compiler from - // 1) inferring that b is 0/1-valued, and - // 2) handling the two cases with a branch. 
- // This is not necessary when verify.c and kem.c are separate translation - // units, but we expect that downstream consumers will copy this code and/or - // change how it is built. - __asm__("" : "+r"(b) : /* no inputs */); -#endif - - b = -b; - for(i=0;i -#include -#include "params.h" - -#define verify KYBER_NAMESPACE(verify) -int verify(const uint8_t *a, const uint8_t *b, size_t len); - -#define cmov KYBER_NAMESPACE(cmov) -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); - -#define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/align.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/align.h deleted file mode 100644 index 3463866f37..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/align.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef ALIGN_H -#define ALIGN_H - -#include -#include - -#define ALIGNED_UINT8(N) \ - union { \ - uint8_t coeffs[N]; \ - __m256i vec[(N+31)/32]; \ - } - -#define ALIGNED_INT16(N) \ - union { \ - int16_t coeffs[N]; \ - __m256i vec[(N+15)/16]; \ - } - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/api.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/api.h deleted file mode 100644 index a154e80f1d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/api.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef API_H -#define API_H - -#include - -#define pqcrystals_kyber512_SECRETKEYBYTES 1632 -#define pqcrystals_kyber512_PUBLICKEYBYTES 800 -#define pqcrystals_kyber512_CIPHERTEXTBYTES 768 -#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber512_ENCCOINBYTES 32 -#define pqcrystals_kyber512_BYTES 32 - -#define pqcrystals_kyber512_avx2_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES -#define pqcrystals_kyber512_avx2_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES -#define 
pqcrystals_kyber512_avx2_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES -#define pqcrystals_kyber512_avx2_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES -#define pqcrystals_kyber512_avx2_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES -#define pqcrystals_kyber512_avx2_BYTES pqcrystals_kyber512_BYTES - -int pqcrystals_kyber512_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber512_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber512_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber512_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber512_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber768_SECRETKEYBYTES 2400 -#define pqcrystals_kyber768_PUBLICKEYBYTES 1184 -#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088 -#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber768_ENCCOINBYTES 32 -#define pqcrystals_kyber768_BYTES 32 - -#define pqcrystals_kyber768_avx2_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES -#define pqcrystals_kyber768_avx2_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES -#define pqcrystals_kyber768_avx2_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES -#define pqcrystals_kyber768_avx2_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES -#define pqcrystals_kyber768_avx2_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES -#define pqcrystals_kyber768_avx2_BYTES pqcrystals_kyber768_BYTES - -int pqcrystals_kyber768_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber768_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber768_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber768_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber768_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define 
pqcrystals_kyber1024_SECRETKEYBYTES 3168 -#define pqcrystals_kyber1024_PUBLICKEYBYTES 1568 -#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568 -#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber1024_ENCCOINBYTES 32 -#define pqcrystals_kyber1024_BYTES 32 - -#define pqcrystals_kyber1024_avx2_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES -#define pqcrystals_kyber1024_avx2_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES -#define pqcrystals_kyber1024_avx2_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES -#define pqcrystals_kyber1024_avx2_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES -#define pqcrystals_kyber1024_avx2_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES -#define pqcrystals_kyber1024_avx2_BYTES pqcrystals_kyber1024_BYTES - -int pqcrystals_kyber1024_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber1024_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber1024_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber1024_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber1024_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.c deleted file mode 100644 index dad473c79e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.c +++ /dev/null @@ -1,144 +0,0 @@ -#include -#include -#include "params.h" -#include "cbd.h" - -/************************************************* -* Name: cbd2 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const __m256i *buf: pointer to aligned input byte array 
-**************************************************/ -static void cbd2(poly * restrict r, const __m256i buf[2*KYBER_N/128]) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i mask55 = _mm256_set1_epi32(0x55555555); - const __m256i mask33 = _mm256_set1_epi32(0x33333333); - const __m256i mask03 = _mm256_set1_epi32(0x03030303); - const __m256i mask0F = _mm256_set1_epi32(0x0F0F0F0F); - - for(i = 0; i < KYBER_N/64; i++) { - f0 = _mm256_load_si256(&buf[i]); - - f1 = _mm256_srli_epi16(f0, 1); - f0 = _mm256_and_si256(mask55, f0); - f1 = _mm256_and_si256(mask55, f1); - f0 = _mm256_add_epi8(f0, f1); - - f1 = _mm256_srli_epi16(f0, 2); - f0 = _mm256_and_si256(mask33, f0); - f1 = _mm256_and_si256(mask33, f1); - f0 = _mm256_add_epi8(f0, mask33); - f0 = _mm256_sub_epi8(f0, f1); - - f1 = _mm256_srli_epi16(f0, 4); - f0 = _mm256_and_si256(mask0F, f0); - f1 = _mm256_and_si256(mask0F, f1); - f0 = _mm256_sub_epi8(f0, mask03); - f1 = _mm256_sub_epi8(f1, mask03); - - f2 = _mm256_unpacklo_epi8(f0, f1); - f3 = _mm256_unpackhi_epi8(f0, f1); - - f0 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(f2)); - f1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(f2,1)); - f2 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(f3)); - f3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(f3,1)); - - _mm256_store_si256(&r->vec[4*i+0], f0); - _mm256_store_si256(&r->vec[4*i+1], f2); - _mm256_store_si256(&r->vec[4*i+2], f1); - _mm256_store_si256(&r->vec[4*i+3], f3); - } -} - -#if KYBER_ETA1 == 3 -/************************************************* -* Name: cbd3 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=3 -* This function is only needed for Kyber-512 -* -* Arguments: - poly *r: pointer to output polynomial -* - const __m256i *buf: pointer to aligned input byte array -**************************************************/ -static void cbd3(poly * restrict r, const uint8_t 
buf[3*KYBER_N/4+8]) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i mask249 = _mm256_set1_epi32(0x249249); - const __m256i mask6DB = _mm256_set1_epi32(0x6DB6DB); - const __m256i mask07 = _mm256_set1_epi32(7); - const __m256i mask70 = _mm256_set1_epi32(7 << 16); - const __m256i mask3 = _mm256_set1_epi16(3); - const __m256i shufbidx = _mm256_set_epi8(-1,15,14,13,-1,12,11,10,-1, 9, 8, 7,-1, 6, 5, 4, - -1,11,10, 9,-1, 8, 7, 6,-1, 5, 4, 3,-1, 2, 1, 0); - - for(i = 0; i < KYBER_N/32; i++) { - f0 = _mm256_loadu_si256((__m256i *)&buf[24*i]); - f0 = _mm256_permute4x64_epi64(f0,0x94); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - - f1 = _mm256_srli_epi32(f0,1); - f2 = _mm256_srli_epi32(f0,2); - f0 = _mm256_and_si256(mask249,f0); - f1 = _mm256_and_si256(mask249,f1); - f2 = _mm256_and_si256(mask249,f2); - f0 = _mm256_add_epi32(f0,f1); - f0 = _mm256_add_epi32(f0,f2); - - f1 = _mm256_srli_epi32(f0,3); - f0 = _mm256_add_epi32(f0,mask6DB); - f0 = _mm256_sub_epi32(f0,f1); - - f1 = _mm256_slli_epi32(f0,10); - f2 = _mm256_srli_epi32(f0,12); - f3 = _mm256_srli_epi32(f0, 2); - f0 = _mm256_and_si256(f0,mask07); - f1 = _mm256_and_si256(f1,mask70); - f2 = _mm256_and_si256(f2,mask07); - f3 = _mm256_and_si256(f3,mask70); - f0 = _mm256_add_epi16(f0,f1); - f1 = _mm256_add_epi16(f2,f3); - f0 = _mm256_sub_epi16(f0,mask3); - f1 = _mm256_sub_epi16(f1,mask3); - - f2 = _mm256_unpacklo_epi32(f0,f1); - f3 = _mm256_unpackhi_epi32(f0,f1); - - f0 = _mm256_permute2x128_si256(f2,f3,0x20); - f1 = _mm256_permute2x128_si256(f2,f3,0x31); - - _mm256_store_si256(&r->vec[2*i+0], f0); - _mm256_store_si256(&r->vec[2*i+1], f1); - } -} -#endif - -/* buf 32 bytes longer for cbd3 */ -void poly_cbd_eta1(poly *r, const __m256i buf[KYBER_ETA1*KYBER_N/128+1]) -{ -#if KYBER_ETA1 == 2 - cbd2(r, buf); -#elif KYBER_ETA1 == 3 - cbd3(r, (uint8_t *)buf); -#else -#error "This implementation requires eta1 in {2,3}" -#endif -} - -void poly_cbd_eta2(poly *r, const __m256i buf[KYBER_ETA2*KYBER_N/128]) -{ -#if KYBER_ETA2 
== 2 - cbd2(r, buf); -#else -#error "This implementation requires eta2 = 2" -#endif -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.h deleted file mode 100644 index 05788e06b4..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/cbd.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef CBD_H -#define CBD_H - -#include -#include -#include "params.h" -#include "poly.h" - -#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1) -void poly_cbd_eta1(poly *r, const __m256i buf[KYBER_ETA1*KYBER_N/128+1]); - -#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2) -void poly_cbd_eta2(poly *r, const __m256i buf[KYBER_ETA2*KYBER_N/128]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.c deleted file mode 100644 index 84e596893d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.c +++ /dev/null @@ -1,121 +0,0 @@ -#include "align.h" -#include "params.h" -#include "consts.h" - -#define Q KYBER_Q -#define MONT -1044 // 2^16 mod q -#define QINV -3327 // q^-1 mod 2^16 -#define V 20159 // floor(2^26/q + 0.5) -#define FHI 1441 // mont^2/128 -#define FLO -10079 // qinv*FHI -#define MONTSQHI 1353 // mont^2 -#define MONTSQLO 20553 // qinv*MONTSQHI -#define MASK 4095 -#define SHIFT 32 - -const qdata_t qdata = {{ -#define _16XQ 0 - Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, - -#define _16XQINV 16 - QINV, QINV, QINV, QINV, QINV, QINV, QINV, QINV, - QINV, QINV, QINV, QINV, QINV, QINV, QINV, QINV, - -#define _16XV 32 - V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, - -#define _16XFLO 48 - FLO, FLO, FLO, FLO, FLO, FLO, FLO, FLO, - FLO, FLO, FLO, FLO, FLO, FLO, FLO, FLO, - -#define _16XFHI 64 - FHI, FHI, FHI, FHI, FHI, FHI, FHI, FHI, - FHI, FHI, FHI, FHI, FHI, FHI, FHI, FHI, - -#define _16XMONTSQLO 80 - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, 
MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - -#define _16XMONTSQHI 96 - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - -#define _16XMASK 112 - MASK, MASK, MASK, MASK, MASK, MASK, MASK, MASK, - MASK, MASK, MASK, MASK, MASK, MASK, MASK, MASK, - -#define _REVIDXB 128 - 3854, 3340, 2826, 2312, 1798, 1284, 770, 256, - 3854, 3340, 2826, 2312, 1798, 1284, 770, 256, - -#define _REVIDXD 144 - 7, 0, 6, 0, 5, 0, 4, 0, 3, 0, 2, 0, 1, 0, 0, 0, - -#define _ZETAS_EXP 160 - 31498, 31498, 31498, 31498, -758, -758, -758, -758, - 5237, 5237, 5237, 5237, 1397, 1397, 1397, 1397, - 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, - 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - 13525, 13525, 13525, 13525, 13525, 13525, 13525, 13525, - -12402, -12402, -12402, -12402, -12402, -12402, -12402, -12402, - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - -20907, -20907, -20907, -20907, 27758, 27758, 27758, 27758, - -3799, -3799, -3799, -3799, -15690, -15690, -15690, -15690, - -171, -171, -171, -171, 622, 622, 622, 622, - 1577, 1577, 1577, 1577, 182, 182, 182, 182, - -5827, -5827, 17363, 17363, -26360, -26360, -29057, -29057, - 5571, 5571, -1102, -1102, 21438, 21438, -26242, -26242, - 573, 573, -1325, -1325, 264, 264, 383, 383, - -829, -829, 1458, 1458, -1602, -1602, -130, -130, - -5689, -6516, 1496, 30967, -23565, 20179, 20710, 25080, - -12796, 26616, 16064, -12442, 9134, -650, -25986, 27837, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, -1618, -1162, 126, 1469, - -335, -11477, -32227, 20494, -27738, 945, -14883, 6182, - 32010, 10631, 29175, -28762, -18486, 17560, -14430, -5276, - -1103, 555, -1251, 1550, 422, 
177, -291, 1574, - -246, 1159, -777, -602, -1590, -872, 418, -156, - 11182, 13387, -14233, -21655, 13131, -4587, 23092, 5493, - -32502, 30317, -18741, 12639, 20100, 18525, 19529, -12619, - 430, 843, 871, 105, 587, -235, -460, 1653, - 778, -147, 1483, 1119, 644, 349, 329, -75, - 787, 787, 787, 787, 787, 787, 787, 787, - 787, 787, 787, 787, 787, 787, 787, 787, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - 28191, 28191, 28191, 28191, 28191, 28191, 28191, 28191, - -16694, -16694, -16694, -16694, -16694, -16694, -16694, -16694, - 287, 287, 287, 287, 287, 287, 287, 287, - 202, 202, 202, 202, 202, 202, 202, 202, - 10690, 10690, 10690, 10690, 1358, 1358, 1358, 1358, - -11202, -11202, -11202, -11202, 31164, 31164, 31164, 31164, - 962, 962, 962, 962, -1202, -1202, -1202, -1202, - -1474, -1474, -1474, -1474, 1468, 1468, 1468, 1468, - -28073, -28073, 24313, 24313, -10532, -10532, 8800, 8800, - 18426, 18426, 8859, 8859, 26675, 26675, -16163, -16163, - -681, -681, 1017, 1017, 732, 732, 608, 608, - -1542, -1542, 411, 411, -205, -205, -1571, -1571, - 19883, -28250, -15887, -8898, -28309, 9075, -30199, 18249, - 13426, 14017, -29156, -12757, 16832, 4311, -24155, -17915, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -31183, 25435, -7382, 24391, -20927, 10946, 24214, 16989, - 10335, -7934, -22502, 10906, 31636, 28644, 23998, -17422, - 817, 603, 1322, -1465, -1215, 1218, -874, -1187, - -1185, -1278, -1510, -870, -108, 996, 958, 1522, - 20297, 2146, 15355, -32384, -6280, -14903, -11044, 14469, - -21498, -20198, 23210, -17442, -23860, -20257, 7756, 23132, - 1097, 610, -1285, 384, -136, -1335, 220, -1659, - -1530, 794, -854, 478, -308, 991, -1460, 1628, - -#define _16XSHIFT 624 - SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, - SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT -}}; diff --git 
a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.h deleted file mode 100644 index f95899cd8e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/consts.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef CONSTS_H -#define CONSTS_H - -#include "params.h" - -#define _16XQ 0 -#define _16XQINV 16 -#define _16XV 32 -#define _16XFLO 48 -#define _16XFHI 64 -#define _16XMONTSQLO 80 -#define _16XMONTSQHI 96 -#define _16XMASK 112 -#define _REVIDXB 128 -#define _REVIDXD 144 -#define _ZETAS_EXP 160 -#define _16XSHIFT 624 - -/* The C ABI on MacOS exports all symbols with a leading - * underscore. This means that any symbols we refer to from - * C files (functions) can't be found, and all symbols we - * refer to from ASM also can't be found. - * - * This define helps us get around this - */ -#ifdef __ASSEMBLER__ -#if defined(__WIN32__) || defined(__APPLE__) -#define decorate(s) _##s -#define cdecl2(s) decorate(s) -#define cdecl(s) cdecl2(KYBER_NAMESPACE(##s)) -#else -#define cdecl(s) KYBER_NAMESPACE(##s) -#endif -#endif - -#ifndef __ASSEMBLER__ -#include "align.h" -typedef ALIGNED_INT16(640) qdata_t; -#define qdata KYBER_NAMESPACE(qdata) -extern const qdata_t qdata; -#endif - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/indcpa.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/indcpa.c deleted file mode 100644 index c4b2b3a89f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/indcpa.c +++ /dev/null @@ -1,568 +0,0 @@ -#include -#include -#include -#include -#include "align.h" -#include "params.h" -#include "indcpa.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" -#include "cbd.h" -#include "rejsample.h" -#include "symmetric.h" -#include "randombytes.h" - -/************************************************* -* Name: pack_pk -* -* Description: Serialize the public key as concatenation of the -* serialized 
vector of polynomials pk and the -* public seed used to generate the matrix A. -* The polynomial coefficients in pk are assumed to -* lie in the invertal [0,q], i.e. pk must be reduced -* by polyvec_reduce(). -* -* Arguments: uint8_t *r: pointer to the output serialized public key -* polyvec *pk: pointer to the input public-key polyvec -* const uint8_t *seed: pointer to the input public seed -**************************************************/ -static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], - polyvec *pk, - const uint8_t seed[KYBER_SYMBYTES]) -{ - polyvec_tobytes(r, pk); - memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES); -} - -/************************************************* -* Name: unpack_pk -* -* Description: De-serialize public key from a byte array; -* approximate inverse of pack_pk -* -* Arguments: - polyvec *pk: pointer to output public-key polynomial vector -* - uint8_t *seed: pointer to output seed to generate matrix A -* - const uint8_t *packedpk: pointer to input serialized public key -**************************************************/ -static void unpack_pk(polyvec *pk, - uint8_t seed[KYBER_SYMBYTES], - const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) -{ - polyvec_frombytes(pk, packedpk); - memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES); -} - -/************************************************* -* Name: pack_sk -* -* Description: Serialize the secret key. -* The polynomial coefficients in sk are assumed to -* lie in the invertal [0,q], i.e. sk must be reduced -* by polyvec_reduce(). 
-* -* Arguments: - uint8_t *r: pointer to output serialized secret key -* - polyvec *sk: pointer to input vector of polynomials (secret key) -**************************************************/ -static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) -{ - polyvec_tobytes(r, sk); -} - -/************************************************* -* Name: unpack_sk -* -* Description: De-serialize the secret key; inverse of pack_sk -* -* Arguments: - polyvec *sk: pointer to output vector of polynomials (secret key) -* - const uint8_t *packedsk: pointer to input serialized secret key -**************************************************/ -static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) -{ - polyvec_frombytes(sk, packedsk); -} - -/************************************************* -* Name: pack_ciphertext -* -* Description: Serialize the ciphertext as concatenation of the -* compressed and serialized vector of polynomials b -* and the compressed and serialized polynomial v. -* The polynomial coefficients in b and v are assumed to -* lie in the invertal [0,q], i.e. b and v must be reduced -* by polyvec_reduce() and poly_reduce(), respectively. 
-* -* Arguments: uint8_t *r: pointer to the output serialized ciphertext -* poly *pk: pointer to the input vector of polynomials b -* poly *v: pointer to the input polynomial v -**************************************************/ -static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v) -{ - polyvec_compress(r, b); - poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v); -} - -/************************************************* -* Name: unpack_ciphertext -* -* Description: De-serialize and decompress ciphertext from a byte array; -* approximate inverse of pack_ciphertext -* -* Arguments: - polyvec *b: pointer to the output vector of polynomials b -* - poly *v: pointer to the output polynomial v -* - const uint8_t *c: pointer to the input serialized ciphertext -**************************************************/ -static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES]) -{ - polyvec_decompress(b, c); - poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES); -} - -/************************************************* -* Name: rej_uniform -* -* Description: Run rejection sampling on uniform random bytes to generate -* uniform random integers mod q -* -* Arguments: - int16_t *r: pointer to output array -* - unsigned int len: requested number of 16-bit integers (uniform mod q) -* - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes) -* - unsigned int buflen: length of input buffer in bytes -* -* Returns number of sampled 16-bit integers (at most len) -**************************************************/ -static unsigned int rej_uniform(int16_t *r, - unsigned int len, - const uint8_t *buf, - unsigned int buflen) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - - ctr = pos = 0; - while(ctr < len && pos <= buflen - 3) { // buflen is always at least 3 - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF; - pos += 3; - - 
if(val0 < KYBER_Q) - r[ctr++] = val0; - if(ctr < len && val1 < KYBER_Q) - r[ctr++] = val1; - } - - return ctr; -} - -#define gen_a(A,B) gen_matrix(A,B,0) -#define gen_at(A,B) gen_matrix(A,B,1) - -/************************************************* -* Name: gen_matrix -* -* Description: Deterministically generate matrix A (or the transpose of A) -* from a seed. Entries of the matrix are polynomials that look -* uniformly random. Performs rejection sampling on output of -* a XOF -* -* Arguments: - polyvec *a: pointer to ouptput matrix A -* - const uint8_t *seed: pointer to input seed -* - int transposed: boolean deciding whether A or A^T is generated -**************************************************/ -#if KYBER_K == 2 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 0; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 1; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 1; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 0; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 1; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 1; - } - - shake128x4_inc_init(&state); - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[1].vec[0].coeffs, 
buf[2].coeffs); - ctr3 = rej_uniform_avx(a[1].vec[1].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[1].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[1].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[0].vec[0]); - poly_nttunpack(&a[0].vec[1]); - poly_nttunpack(&a[1].vec[0]); - poly_nttunpack(&a[1].vec[1]); - shake128x4_inc_ctx_release(&state); -} -#elif KYBER_K == 3 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - shake128incctx state1x; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 0; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 0; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 0; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 0; - buf[3].coeffs[33] = 1; - } - - shake128x4_inc_init(&state); - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = 
rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[0].vec[2].coeffs, buf[2].coeffs); - ctr3 = rej_uniform_avx(a[1].vec[0].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[0].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[1].vec[0].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[0].vec[0]); - poly_nttunpack(&a[0].vec[1]); - poly_nttunpack(&a[0].vec[2]); - poly_nttunpack(&a[1].vec[0]); - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 1; - buf[0].coeffs[33] = 1; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 2; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 2; - buf[3].coeffs[33] = 1; - } - else { - buf[0].coeffs[32] = 1; - buf[0].coeffs[33] = 1; - buf[1].coeffs[32] = 2; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 2; - } - - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[1].vec[1].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[1].vec[2].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[2].vec[0].coeffs, buf[2].coeffs); - ctr3 = 
rej_uniform_avx(a[2].vec[1].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[1].vec[1].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[1].vec[2].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[2].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[2].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - shake128x4_inc_ctx_release(&state); - - poly_nttunpack(&a[1].vec[1]); - poly_nttunpack(&a[1].vec[2]); - poly_nttunpack(&a[2].vec[0]); - poly_nttunpack(&a[2].vec[1]); - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - buf[0].coeffs[32] = 2; - buf[0].coeffs[33] = 2; - - shake128_inc_init(&state1x); - shake128_absorb_once(&state1x, buf[0].coeffs, 34); - shake128_squeezeblocks(buf[0].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state1x); - ctr0 = rej_uniform_avx(a[2].vec[2].coeffs, buf[0].coeffs); - while(ctr0 < KYBER_N) { - shake128_squeezeblocks(buf[0].coeffs, 1, &state1x); - ctr0 += rej_uniform(a[2].vec[2].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - } - shake128_inc_ctx_release(&state1x); - - poly_nttunpack(&a[2].vec[2]); -} -#elif KYBER_K == 4 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int i, ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - shake128x4_inc_init(&state); - - for(i=0;i<4;i++) { - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = i; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = i; - buf[1].coeffs[33] = 
1; - buf[2].coeffs[32] = i; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = i; - buf[3].coeffs[33] = 3; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = i; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = i; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = i; - buf[3].coeffs[32] = 3; - buf[3].coeffs[33] = i; - } - - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[i].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[i].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[i].vec[2].coeffs, buf[2].coeffs); - ctr3 = rej_uniform_avx(a[i].vec[3].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[i].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[i].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[i].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[i].vec[3].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[i].vec[0]); - poly_nttunpack(&a[i].vec[1]); - poly_nttunpack(&a[i].vec[2]); - poly_nttunpack(&a[i].vec[3]); - } - shake128x4_inc_ctx_release(&state); -} -#endif - -/************************************************* -* Name: indcpa_keypair_derand -* -* Description: Generates public and private key for the CPA-secure -* public-key encryption scheme underlying Kyber -* -* Arguments: - uint8_t *pk: pointer to output public key -* (of length KYBER_INDCPA_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (of length KYBER_INDCPA_SECRETKEYBYTES bytes) -* - const uint8_t *coins: pointer to input 
randomness -* (of length KYBER_SYMBYTES bytes) -**************************************************/ -void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]) -{ - unsigned int i; - uint8_t buf[2*KYBER_SYMBYTES]; - const uint8_t *publicseed = buf; - const uint8_t *noiseseed = buf + KYBER_SYMBYTES; - polyvec a[KYBER_K], e, pkpv, skpv; - - memcpy(buf, coins, KYBER_SYMBYTES); - buf[KYBER_SYMBYTES] = KYBER_K; - hash_g(buf, buf, KYBER_SYMBYTES+1); - - gen_a(a, publicseed); - -#if KYBER_K == 2 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, e.vec+0, e.vec+1, noiseseed, 0, 1, 2, 3); -#elif KYBER_K == 3 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, e.vec+0, noiseseed, 0, 1, 2, 3); - poly_getnoise_eta1_4x(e.vec+1, e.vec+2, pkpv.vec+0, pkpv.vec+1, noiseseed, 4, 5, 6, 7); -#elif KYBER_K == 4 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, skpv.vec+3, noiseseed, 0, 1, 2, 3); - poly_getnoise_eta1_4x(e.vec+0, e.vec+1, e.vec+2, e.vec+3, noiseseed, 4, 5, 6, 7); -#endif - - polyvec_ntt(&skpv); - polyvec_reduce(&skpv); - polyvec_ntt(&e); - - // matrix-vector multiplication - for(i=0;i -#include "params.h" -#include "polyvec.h" - -#define gen_matrix KYBER_NAMESPACE(gen_matrix) -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed); - -#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand) -void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_enc KYBER_NAMESPACE(indcpa_enc) -void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_dec KYBER_NAMESPACE(indcpa_dec) -void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t 
sk[KYBER_INDCPA_SECRETKEYBYTES]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/invntt.S b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/invntt.S deleted file mode 100644 index 76d4189996..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/invntt.S +++ /dev/null @@ -1,193 +0,0 @@ -#include "consts.h" -.include "shuffle.inc" -.include "fq.inc" - -.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3 -vpsubw %ymm\rl0,%ymm\rh0,%ymm12 -vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 -vpsubw %ymm\rl1,%ymm\rh1,%ymm13 - -vpmullw %ymm\zl0,%ymm12,%ymm\rh0 -vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 -vpsubw %ymm\rl2,%ymm\rh2,%ymm14 - -vpmullw %ymm\zl0,%ymm13,%ymm\rh1 -vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 -vpsubw %ymm\rl3,%ymm\rh3,%ymm15 - -vpmullw %ymm\zl1,%ymm14,%ymm\rh2 -vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 -vpmullw %ymm\zl1,%ymm15,%ymm\rh3 - -vpmulhw %ymm\zh0,%ymm12,%ymm12 -vpmulhw %ymm\zh0,%ymm13,%ymm13 - -vpmulhw %ymm\zh1,%ymm14,%ymm14 -vpmulhw %ymm\zh1,%ymm15,%ymm15 - -vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 - -vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 - -vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 -vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 - -# - -# - -vpsubw %ymm\rh0,%ymm12,%ymm\rh0 - -vpsubw %ymm\rh1,%ymm13,%ymm\rh1 - -vpsubw %ymm\rh2,%ymm14,%ymm\rh2 -vpsubw %ymm\rh3,%ymm15,%ymm\rh3 -.endm - -.macro intt_levels0t5 off -/* level 0 */ -vmovdqa _16XFLO*2(%rsi),%ymm2 -vmovdqa _16XFHI*2(%rsi),%ymm3 - -vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 - -fqmulprecomp 2,3,4 -fqmulprecomp 2,3,6 -fqmulprecomp 2,3,5 -fqmulprecomp 2,3,7 - -vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 -vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 -vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 -vmovdqa (128*\off+112)*2(%rdi),%ymm11 - -fqmulprecomp 2,3,8 -fqmulprecomp 2,3,10 -fqmulprecomp 2,3,9 -fqmulprecomp 2,3,11 - -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15 -vpermq 
$0x4E,(_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+192)*2(%rsi),%ymm3 -vmovdqa _REVIDXB*2(%rsi),%ymm12 -vpshufb %ymm12,%ymm15,%ymm15 -vpshufb %ymm12,%ymm1,%ymm1 -vpshufb %ymm12,%ymm2,%ymm2 -vpshufb %ymm12,%ymm3,%ymm3 - -butterfly 4,5,8,9,6,7,10,11,15,1,2,3 - -/* level 1 */ -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3 -vmovdqa _REVIDXB*2(%rsi),%ymm1 -vpshufb %ymm1,%ymm2,%ymm2 -vpshufb %ymm1,%ymm3,%ymm3 - -butterfly 4,5,6,7,8,9,10,11,2,2,3,3 - -shuffle1 4,5,3,5 -shuffle1 6,7,4,7 -shuffle1 8,9,6,9 -shuffle1 10,11,8,11 - -/* level 2 */ -vmovdqa _REVIDXD*2(%rsi),%ymm12 -vpermd (_ZETAS_EXP+(1-\off)*224+112)*2(%rsi),%ymm12,%ymm2 -vpermd (_ZETAS_EXP+(1-\off)*224+128)*2(%rsi),%ymm12,%ymm10 - -butterfly 3,4,6,8,5,7,9,11,2,2,10,10 - -vmovdqa _16XV*2(%rsi),%ymm1 -red16 3 - -shuffle2 3,4,10,4 -shuffle2 6,8,3,8 -shuffle2 5,7,6,7 -shuffle2 9,11,5,11 - -/* level 3 */ -vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+80)*2(%rsi),%ymm2 -vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+96)*2(%rsi),%ymm9 - -butterfly 10,3,6,5,4,8,7,11,2,2,9,9 - -shuffle4 10,3,9,3 -shuffle4 6,5,10,5 -shuffle4 4,8,6,8 -shuffle4 7,11,4,11 - -/* level 4 */ -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+48)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+64)*2(%rsi),%ymm7 - -butterfly 9,10,6,4,3,5,8,11,2,2,7,7 - -red16 9 - -shuffle8 9,10,7,10 -shuffle8 6,4,9,4 -shuffle8 3,5,6,5 -shuffle8 8,11,3,11 - -/* level 5 */ -vmovdqa (_ZETAS_EXP+(1-\off)*224+16)*2(%rsi),%ymm2 -vmovdqa (_ZETAS_EXP+(1-\off)*224+32)*2(%rsi),%ymm8 - -butterfly 7,9,6,3,10,4,5,11,2,2,8,8 - -vmovdqa %ymm7,(128*\off+ 0)*2(%rdi) -vmovdqa %ymm9,(128*\off+ 16)*2(%rdi) -vmovdqa %ymm6,(128*\off+ 32)*2(%rdi) -vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) -vmovdqa %ymm10,(128*\off+ 64)*2(%rdi) -vmovdqa %ymm4,(128*\off+ 80)*2(%rdi) -vmovdqa %ymm5,(128*\off+ 96)*2(%rdi) -vmovdqa %ymm11,(128*\off+112)*2(%rdi) -.endm - 
-.macro intt_level6 off -/* level 6 */ -vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (64*\off+128)*2(%rdi),%ymm8 -vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (64*\off+144)*2(%rdi),%ymm9 -vpbroadcastq (_ZETAS_EXP+0)*2(%rsi),%ymm2 - -vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (64*\off+160)*2(%rdi),%ymm10 -vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 -vmovdqa (64*\off+176)*2(%rdi),%ymm11 -vpbroadcastq (_ZETAS_EXP+4)*2(%rsi),%ymm3 - -butterfly 4,5,6,7,8,9,10,11 - -.if \off == 0 -red16 4 -.endif - -vmovdqa %ymm4,(64*\off+ 0)*2(%rdi) -vmovdqa %ymm5,(64*\off+ 16)*2(%rdi) -vmovdqa %ymm6,(64*\off+ 32)*2(%rdi) -vmovdqa %ymm7,(64*\off+ 48)*2(%rdi) -vmovdqa %ymm8,(64*\off+128)*2(%rdi) -vmovdqa %ymm9,(64*\off+144)*2(%rdi) -vmovdqa %ymm10,(64*\off+160)*2(%rdi) -vmovdqa %ymm11,(64*\off+176)*2(%rdi) -.endm - -.text -.global cdecl(invntt_avx) -cdecl(invntt_avx): -vmovdqa _16XQ*2(%rsi),%ymm0 - -intt_levels0t5 0 -intt_levels0t5 1 - -intt_level6 0 -intt_level6 1 -ret diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.c deleted file mode 100644 index 63abc1029c..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.c +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "kem.h" -#include "indcpa.h" -#include "verify.h" -#include "symmetric.h" -#include "randombytes.h" -/************************************************* -* Name: crypto_kem_keypair_derand -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* - uint8_t *coins: pointer to input randomness -* (an already allocated array filled with 2*KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) 
-**************************************************/ -int crypto_kem_keypair_derand(uint8_t *pk, - uint8_t *sk, - const uint8_t *coins) -{ - indcpa_keypair_derand(pk, sk, coins); - memcpy(sk+KYBER_INDCPA_SECRETKEYBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_h(sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - /* Value z for pseudo-random output on reject */ - memcpy(sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES, coins+KYBER_SYMBYTES, KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_keypair -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair(uint8_t *pk, - uint8_t *sk) -{ - uint8_t coins[2*KYBER_SYMBYTES]; - randombytes(coins, 2*KYBER_SYMBYTES); - crypto_kem_keypair_derand(pk, sk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc_derand -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - const uint8_t *coins: pointer to input randomness -* (an already allocated array filled with KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc_derand(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk, - const uint8_t *coins) -{ - uint8_t 
buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - - memcpy(buf, coins, KYBER_SYMBYTES); - - /* Multitarget countermeasure for coins + contributory KEM */ - hash_h(buf+KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(ct, buf, pk, kr+KYBER_SYMBYTES); - - memcpy(ss,kr,KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk) -{ - uint8_t coins[KYBER_SYMBYTES]; - randombytes(coins, KYBER_SYMBYTES); - crypto_kem_enc_derand(ct, ss, pk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_dec -* -* Description: Generates shared secret for given -* cipher text and private key -* -* Arguments: - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *ct: pointer to input cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - const uint8_t *sk: pointer to input private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0. -* -* On failure, ss will contain a pseudo-random value. 
-**************************************************/ -int crypto_kem_dec(uint8_t *ss, - const uint8_t *ct, - const uint8_t *sk) -{ - int fail; - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - uint8_t cmp[KYBER_CIPHERTEXTBYTES+KYBER_SYMBYTES]; - const uint8_t *pk = sk+KYBER_INDCPA_SECRETKEYBYTES; - - indcpa_dec(buf, ct, sk); - - /* Multitarget countermeasure for coins + contributory KEM */ - memcpy(buf+KYBER_SYMBYTES, sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, KYBER_SYMBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(cmp, buf, pk, kr+KYBER_SYMBYTES); - - fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES); - - /* Compute rejection key */ - rkprf(ss,sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES,ct); - - /* Copy true key to return buffer if fail is false */ - cmov(ss,kr,KYBER_SYMBYTES,!fail); - - return 0; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.h deleted file mode 100644 index 234f11966b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/kem.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef KEM_H -#define KEM_H - -#include -#include "params.h" - -#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES -#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES -#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES -#define CRYPTO_BYTES KYBER_SSBYTES - -#if (KYBER_K == 2) -#define CRYPTO_ALGNAME "Kyber512" -#elif (KYBER_K == 3) -#define CRYPTO_ALGNAME "Kyber768" -#elif (KYBER_K == 4) -#define CRYPTO_ALGNAME "Kyber1024" -#endif - -#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand) -int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); - -#define crypto_kem_keypair KYBER_NAMESPACE(keypair) -int crypto_kem_keypair(uint8_t *pk, uint8_t *sk); - -#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand) -int crypto_kem_enc_derand(uint8_t *ct, 
uint8_t *ss, const uint8_t *pk, const uint8_t *coins); - -#define crypto_kem_enc KYBER_NAMESPACE(enc) -int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); - -#define crypto_kem_dec KYBER_NAMESPACE(dec) -int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.S b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.S deleted file mode 100644 index 0ce7b41297..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.S +++ /dev/null @@ -1,189 +0,0 @@ -#include "consts.h" -.include "shuffle.inc" - -.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2 -vpmullw %ymm\zl0,%ymm\rh0,%ymm12 -vpmullw %ymm\zl0,%ymm\rh1,%ymm13 - -vpmullw %ymm\zl1,%ymm\rh2,%ymm14 -vpmullw %ymm\zl1,%ymm\rh3,%ymm15 - -vpmulhw %ymm\zh0,%ymm\rh0,%ymm\rh0 -vpmulhw %ymm\zh0,%ymm\rh1,%ymm\rh1 - -vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2 -vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3 -.endm - -.macro reduce -vpmulhw %ymm0,%ymm12,%ymm12 -vpmulhw %ymm0,%ymm13,%ymm13 - -vpmulhw %ymm0,%ymm14,%ymm14 -vpmulhw %ymm0,%ymm15,%ymm15 -.endm - -.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3 -vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln -vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 -vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 - -vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 -vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 -vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 - -vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 -vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 - -vpsubw %ymm12,%ymm\rln,%ymm\rln -vpaddw %ymm12,%ymm\rh0,%ymm\rh0 -vpsubw %ymm13,%ymm\rl0,%ymm\rl0 - -vpaddw %ymm13,%ymm\rh1,%ymm\rh1 -vpsubw %ymm14,%ymm\rl1,%ymm\rl1 -vpaddw %ymm14,%ymm\rh2,%ymm\rh2 - -vpsubw %ymm15,%ymm\rl2,%ymm\rl2 -vpaddw %ymm15,%ymm\rh3,%ymm\rh3 -.endm - -.macro level0 off -vpbroadcastq (_ZETAS_EXP+0)*2(%rsi),%ymm15 -vmovdqa (64*\off+128)*2(%rdi),%ymm8 -vmovdqa (64*\off+144)*2(%rdi),%ymm9 -vmovdqa (64*\off+160)*2(%rdi),%ymm10 -vmovdqa (64*\off+176)*2(%rdi),%ymm11 -vpbroadcastq 
(_ZETAS_EXP+4)*2(%rsi),%ymm2 - -mul 8,9,10,11 - -vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 - -reduce -update 3,4,5,6,7,8,9,10,11 - -vmovdqa %ymm3,(64*\off+ 0)*2(%rdi) -vmovdqa %ymm4,(64*\off+ 16)*2(%rdi) -vmovdqa %ymm5,(64*\off+ 32)*2(%rdi) -vmovdqa %ymm6,(64*\off+ 48)*2(%rdi) -vmovdqa %ymm8,(64*\off+128)*2(%rdi) -vmovdqa %ymm9,(64*\off+144)*2(%rdi) -vmovdqa %ymm10,(64*\off+160)*2(%rdi) -vmovdqa %ymm11,(64*\off+176)*2(%rdi) -.endm - -.macro levels1t6 off -/* level 1 */ -vmovdqa (_ZETAS_EXP+224*\off+16)*2(%rsi),%ymm15 -vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 -vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 -vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 -vmovdqa (128*\off+112)*2(%rdi),%ymm11 -vmovdqa (_ZETAS_EXP+224*\off+32)*2(%rsi),%ymm2 - -mul 8,9,10,11 - -vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 - -reduce -update 3,4,5,6,7,8,9,10,11 - -/* level 2 */ -shuffle8 5,10,7,10 -shuffle8 6,11,5,11 - -vmovdqa (_ZETAS_EXP+224*\off+48)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+64)*2(%rsi),%ymm2 - -mul 7,10,5,11 - -shuffle8 3,8,6,8 -shuffle8 4,9,3,9 - -reduce -update 4,6,8,3,9,7,10,5,11 - -/* level 3 */ -shuffle4 8,5,9,5 -shuffle4 3,11,8,11 - -vmovdqa (_ZETAS_EXP+224*\off+80)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+96)*2(%rsi),%ymm2 - -mul 9,5,8,11 - -shuffle4 4,7,3,7 -shuffle4 6,10,4,10 - -reduce -update 6,3,7,4,10,9,5,8,11 - -/* level 4 */ -shuffle2 7,8,10,8 -shuffle2 4,11,7,11 - -vmovdqa (_ZETAS_EXP+224*\off+112)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+128)*2(%rsi),%ymm2 - -mul 10,8,7,11 - -shuffle2 6,9,4,9 -shuffle2 3,5,6,5 - -reduce -update 3,4,9,6,5,10,8,7,11 - -/* level 5 */ -shuffle1 9,7,5,7 -shuffle1 6,11,9,11 - -vmovdqa (_ZETAS_EXP+224*\off+144)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+160)*2(%rsi),%ymm2 - -mul 5,7,9,11 - -shuffle1 3,10,6,10 -shuffle1 4,8,3,8 - 
-reduce -update 4,6,10,3,8,5,7,9,11 - -/* level 6 */ -vmovdqa (_ZETAS_EXP+224*\off+176)*2(%rsi),%ymm14 -vmovdqa (_ZETAS_EXP+224*\off+208)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+192)*2(%rsi),%ymm8 -vmovdqa (_ZETAS_EXP+224*\off+224)*2(%rsi),%ymm2 - -mul 10,3,9,11,14,15,8,2 - -reduce -update 8,4,6,5,7,10,3,9,11 - -vmovdqa %ymm8,(128*\off+ 0)*2(%rdi) -vmovdqa %ymm4,(128*\off+ 16)*2(%rdi) -vmovdqa %ymm10,(128*\off+ 32)*2(%rdi) -vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) -vmovdqa %ymm6,(128*\off+ 64)*2(%rdi) -vmovdqa %ymm5,(128*\off+ 80)*2(%rdi) -vmovdqa %ymm9,(128*\off+ 96)*2(%rdi) -vmovdqa %ymm11,(128*\off+112)*2(%rdi) -.endm - -.text -.global cdecl(ntt_avx) -cdecl(ntt_avx): -vmovdqa _16XQ*2(%rsi),%ymm0 - -level0 0 -level0 1 - -levels1t6 0 -levels1t6 1 - -ret diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.h deleted file mode 100644 index a4f48e343b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/ntt.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef NTT_H -#define NTT_H - -#include -#include - -#define ntt_avx KYBER_NAMESPACE(ntt_avx) -void ntt_avx(__m256i *r, const __m256i *qdata); -#define invntt_avx KYBER_NAMESPACE(invntt_avx) -void invntt_avx(__m256i *r, const __m256i *qdata); - -#define nttpack_avx KYBER_NAMESPACE(nttpack_avx) -void nttpack_avx(__m256i *r, const __m256i *qdata); -#define nttunpack_avx KYBER_NAMESPACE(nttunpack_avx) -void nttunpack_avx(__m256i *r, const __m256i *qdata); - -#define basemul_avx KYBER_NAMESPACE(basemul_avx) -void basemul_avx(__m256i *r, - const __m256i *a, - const __m256i *b, - const __m256i *qdata); - -#define ntttobytes_avx KYBER_NAMESPACE(ntttobytes_avx) -void ntttobytes_avx(uint8_t *r, const __m256i *a, const __m256i *qdata); -#define nttfrombytes_avx KYBER_NAMESPACE(nttfrombytes_avx) -void nttfrombytes_avx(__m256i *r, const uint8_t *a, const __m256i *qdata); - -#endif diff --git 
a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/params.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/params.h deleted file mode 100644 index ecfabce4a5..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/params.h +++ /dev/null @@ -1,68 +0,0 @@ -#ifndef PARAMS_H -#define PARAMS_H - -#ifndef KYBER_K -#define KYBER_K 3 /* Change this for different security strengths */ -#endif - -//#define KYBER_90S /* Uncomment this if you want the 90S variant */ - -/* Don't change parameters below this line */ -#if (KYBER_K == 2) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber512_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_512_avx2_##s -#endif -#elif (KYBER_K == 3) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber768_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_768_avx2_##s -#endif -#elif (KYBER_K == 4) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber1024_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_1024_avx2_##s -#endif -#else -#error "KYBER_K must be in {2,3,4}" -#endif - -#define KYBER_N 256 -#define KYBER_Q 3329 - -#define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */ -#define KYBER_SSBYTES 32 /* size in bytes of shared key */ - -#define KYBER_POLYBYTES 384 -#define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES) - -#if KYBER_K == 2 -#define KYBER_ETA1 3 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 3 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 4 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 160 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352) -#endif - -#define KYBER_ETA2 2 - -#define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES) -#define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES) -#define KYBER_INDCPA_SECRETKEYBYTES 
(KYBER_POLYVECBYTES) -#define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES) - -#define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES) -/* 32 bytes of additional space to save H(pk) */ -#define KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2*KYBER_SYMBYTES) -#define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES) - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.c deleted file mode 100644 index 681fd6d23e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.c +++ /dev/null @@ -1,519 +0,0 @@ -#include -#include -#include -#include "align.h" -#include "fips202x4.h" -#include "params.h" -#include "poly.h" -#include "ntt.h" -#include "consts.h" -#include "reduce.h" -#include "cbd.h" -#include "symmetric.h" - -/************************************************* -* Name: poly_compress -* -* Description: Compression and subsequent serialization of a polynomial. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). 
-* -* Arguments: - uint8_t *r: pointer to output byte array -* (of length KYBER_POLYCOMPRESSEDBYTES) -* - const poly *a: pointer to input polynomial -**************************************************/ -#if (KYBER_POLYCOMPRESSEDBYTES == 128) -void poly_compress(uint8_t r[128], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i shift1 = _mm256_set1_epi16(1 << 9); - const __m256i mask = _mm256_set1_epi16(15); - const __m256i shift2 = _mm256_set1_epi16((16 << 8) + 1); - const __m256i permdidx = _mm256_set_epi32(7,3,6,2,5,1,4,0); - - for(i=0;ivec[4*i+0]); - f1 = _mm256_load_si256(&a->vec[4*i+1]); - f2 = _mm256_load_si256(&a->vec[4*i+2]); - f3 = _mm256_load_si256(&a->vec[4*i+3]); - f0 = _mm256_mulhi_epi16(f0,v); - f1 = _mm256_mulhi_epi16(f1,v); - f2 = _mm256_mulhi_epi16(f2,v); - f3 = _mm256_mulhi_epi16(f3,v); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f1 = _mm256_mulhrs_epi16(f1,shift1); - f2 = _mm256_mulhrs_epi16(f2,shift1); - f3 = _mm256_mulhrs_epi16(f3,shift1); - f0 = _mm256_and_si256(f0,mask); - f1 = _mm256_and_si256(f1,mask); - f2 = _mm256_and_si256(f2,mask); - f3 = _mm256_and_si256(f3,mask); - f0 = _mm256_packus_epi16(f0,f1); - f2 = _mm256_packus_epi16(f2,f3); - f0 = _mm256_maddubs_epi16(f0,shift2); - f2 = _mm256_maddubs_epi16(f2,shift2); - f0 = _mm256_packus_epi16(f0,f2); - f0 = _mm256_permutevar8x32_epi32(f0,permdidx); - _mm256_storeu_si256((__m256i *)&r[32*i],f0); - } -} - -void poly_decompress(poly * restrict r, const uint8_t a[128]) -{ - unsigned int i; - __m128i t; - __m256i f; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(7,7,7,7,6,6,6,6,5,5,5,5,4,4,4,4, - 3,3,3,3,2,2,2,2,1,1,1,1,0,0,0,0); - const __m256i mask = _mm256_set1_epi32(0x00F0000F); - const __m256i shift = _mm256_set1_epi32((128 << 16) + 2048); - - for(i=0;ivec[i],f); - } -} - -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) -void poly_compress(uint8_t 
r[160], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i shift1 = _mm256_set1_epi16(1 << 10); - const __m256i mask = _mm256_set1_epi16(31); - const __m256i shift2 = _mm256_set1_epi16((32 << 8) + 1); - const __m256i shift3 = _mm256_set1_epi32((1024 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(12); - const __m256i shufbidx = _mm256_set_epi8( 8,-1,-1,-1,-1,-1, 4, 3, 2, 1, 0,-1,12,11,10, 9, - -1,12,11,10, 9, 8,-1,-1,-1,-1,-1 ,4, 3, 2, 1, 0); - - for(i=0;ivec[2*i+0]); - f1 = _mm256_load_si256(&a->vec[2*i+1]); - f0 = _mm256_mulhi_epi16(f0,v); - f1 = _mm256_mulhi_epi16(f1,v); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f1 = _mm256_mulhrs_epi16(f1,shift1); - f0 = _mm256_and_si256(f0,mask); - f1 = _mm256_and_si256(f1,mask); - f0 = _mm256_packus_epi16(f0,f1); - f0 = _mm256_maddubs_epi16(f0,shift2); // a0 a1 a2 a3 b0 b1 b2 b3 a4 a5 a6 a7 b4 b5 b6 b7 - f0 = _mm256_madd_epi16(f0,shift3); // a0 a1 b0 b1 a2 a3 b2 b3 - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f0 = _mm256_srlv_epi64(f0,sllvdidx); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blendv_epi8(t0,t1,_mm256_castsi256_si128(shufbidx)); - _mm_storeu_si128((__m128i *)&r[20*i+ 0],t0); - memcpy(&r[20*i+16],&t1,4); - } -} - -void poly_decompress(poly * restrict r, const uint8_t a[160]) -{ - unsigned int i; - __m128i t; - __m256i f; - int16_t ti; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(9,9,9,8,8,8,8,7,7,6,6,6,6,5,5,5, - 4,4,4,3,3,3,3,2,2,1,1,1,1,0,0,0); - const __m256i mask = _mm256_set_epi16(248,1984,62,496,3968,124,992,31, - 248,1984,62,496,3968,124,992,31); - const __m256i shift = _mm256_set_epi16(128,16,512,64,8,256,32,1024, - 128,16,512,64,8,256,32,1024); - - for(i=0;ivec[i],f); - } -} - -#endif - -/************************************************* -* Name: 
poly_tobytes -* -* Description: Serialization of a polynomial in NTT representation. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). The coefficients are orderd as output by -* poly_ntt(); the serialized output coefficients are in bitreversed -* order. -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYBYTES bytes) -* - poly *a: pointer to input polynomial -**************************************************/ -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) -{ - ntttobytes_avx(r, a->vec, qdata.vec); -} - -/************************************************* -* Name: poly_frombytes -* -* Description: De-serialization of a polynomial; -* inverse of poly_tobytes -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of KYBER_POLYBYTES bytes) -**************************************************/ -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) -{ - nttfrombytes_avx(r->vec, a, qdata.vec); -} - -/************************************************* -* Name: poly_frommsg -* -* Description: Convert 32-byte message to polynomial -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *msg: pointer to input message -**************************************************/ -void poly_frommsg(poly * restrict r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) -{ -#if (KYBER_INDCPA_MSGBYTES != 32) -#error "KYBER_INDCPA_MSGBYTES must be equal to 32!" 
-#endif - __m256i f, g0, g1, g2, g3, h0, h1, h2, h3; - const __m256i shift = _mm256_broadcastsi128_si256(_mm_set_epi32(0,1,2,3)); - const __m256i idx = _mm256_broadcastsi128_si256(_mm_set_epi8(15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0)); - const __m256i hqs = _mm256_set1_epi16((KYBER_Q+1)/2); - -#define FROMMSG64(i) \ - g3 = _mm256_shuffle_epi32(f,0x55*i); \ - g3 = _mm256_sllv_epi32(g3,shift); \ - g3 = _mm256_shuffle_epi8(g3,idx); \ - g0 = _mm256_slli_epi16(g3,12); \ - g1 = _mm256_slli_epi16(g3,8); \ - g2 = _mm256_slli_epi16(g3,4); \ - g0 = _mm256_srai_epi16(g0,15); \ - g1 = _mm256_srai_epi16(g1,15); \ - g2 = _mm256_srai_epi16(g2,15); \ - g3 = _mm256_srai_epi16(g3,15); \ - g0 = _mm256_and_si256(g0,hqs); /* 19 18 17 16 3 2 1 0 */ \ - g1 = _mm256_and_si256(g1,hqs); /* 23 22 21 20 7 6 5 4 */ \ - g2 = _mm256_and_si256(g2,hqs); /* 27 26 25 24 11 10 9 8 */ \ - g3 = _mm256_and_si256(g3,hqs); /* 31 30 29 28 15 14 13 12 */ \ - h0 = _mm256_unpacklo_epi64(g0,g1); \ - h2 = _mm256_unpackhi_epi64(g0,g1); \ - h1 = _mm256_unpacklo_epi64(g2,g3); \ - h3 = _mm256_unpackhi_epi64(g2,g3); \ - g0 = _mm256_permute2x128_si256(h0,h1,0x20); \ - g2 = _mm256_permute2x128_si256(h0,h1,0x31); \ - g1 = _mm256_permute2x128_si256(h2,h3,0x20); \ - g3 = _mm256_permute2x128_si256(h2,h3,0x31); \ - _mm256_store_si256(&r->vec[0+2*i+0],g0); \ - _mm256_store_si256(&r->vec[0+2*i+1],g1); \ - _mm256_store_si256(&r->vec[8+2*i+0],g2); \ - _mm256_store_si256(&r->vec[8+2*i+1],g3) - - f = _mm256_loadu_si256((__m256i *)msg); - FROMMSG64(0); - FROMMSG64(1); - FROMMSG64(2); - FROMMSG64(3); -} - -/************************************************* -* Name: poly_tomsg -* -* Description: Convert polynomial to 32-byte message. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). 
-* -* Arguments: - uint8_t *msg: pointer to output message -* - poly *a: pointer to input polynomial -**************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly * restrict a) -{ - unsigned int i; - uint32_t small; - __m256i f0, f1, g0, g1; - const __m256i hq = _mm256_set1_epi16((KYBER_Q - 1)/2); - const __m256i hhq = _mm256_set1_epi16((KYBER_Q - 1)/4); - - for(i=0;ivec[2*i+0]); - f1 = _mm256_load_si256(&a->vec[2*i+1]); - f0 = _mm256_sub_epi16(hq, f0); - f1 = _mm256_sub_epi16(hq, f1); - g0 = _mm256_srai_epi16(f0, 15); - g1 = _mm256_srai_epi16(f1, 15); - f0 = _mm256_xor_si256(f0, g0); - f1 = _mm256_xor_si256(f1, g1); - f0 = _mm256_sub_epi16(f0, hhq); - f1 = _mm256_sub_epi16(f1, hhq); - f0 = _mm256_packs_epi16(f0, f1); - f0 = _mm256_permute4x64_epi64(f0, 0xD8); - small = _mm256_movemask_epi8(f0); - memcpy(&msg[4*i], &small, 4); - } -} - -/************************************************* -* Name: poly_getnoise_eta1 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA1 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - ALIGNED_UINT8(KYBER_ETA1*KYBER_N/4+32) buf; // +32 bytes as required by poly_cbd_eta1 - prf(buf.coeffs, KYBER_ETA1*KYBER_N/4, seed, nonce); - poly_cbd_eta1(r, buf.vec); -} - -/************************************************* -* Name: poly_getnoise_eta2 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: 
pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - ALIGNED_UINT8(KYBER_ETA2*KYBER_N/4) buf; - prf(buf.coeffs, KYBER_ETA2*KYBER_N/4, seed, nonce); - poly_cbd_eta2(r, buf.vec); -} - -#ifndef KYBER_90S -#define NOISE_NBLOCKS ((KYBER_ETA1*KYBER_N/4+SHAKE256_RATE-1)/SHAKE256_RATE) -void poly_getnoise_eta1_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3) -{ - ALIGNED_UINT8(NOISE_NBLOCKS*SHAKE256_RATE) buf[4]; - __m256i f; - shake256x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - buf[0].coeffs[32] = nonce0; - buf[1].coeffs[32] = nonce1; - buf[2].coeffs[32] = nonce2; - buf[3].coeffs[32] = nonce3; - - shake256x4_inc_init(&state); - shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33); - shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state); - shake256x4_inc_ctx_release(&state); - - poly_cbd_eta1(r0, buf[0].vec); - poly_cbd_eta1(r1, buf[1].vec); - poly_cbd_eta1(r2, buf[2].vec); - poly_cbd_eta1(r3, buf[3].vec); -} - -#if KYBER_K == 2 -void poly_getnoise_eta1122_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3) -{ - ALIGNED_UINT8(NOISE_NBLOCKS*SHAKE256_RATE) buf[4]; - __m256i f; - shake256x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - buf[0].coeffs[32] = nonce0; - 
buf[1].coeffs[32] = nonce1; - buf[2].coeffs[32] = nonce2; - buf[3].coeffs[32] = nonce3; - - shake256x4_inc_init(&state); - shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33); - shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state); - shake256x4_inc_ctx_release(&state); - - poly_cbd_eta1(r0, buf[0].vec); - poly_cbd_eta1(r1, buf[1].vec); - poly_cbd_eta2(r2, buf[2].vec); - poly_cbd_eta2(r3, buf[3].vec); -} -#endif -#endif - -/************************************************* -* Name: poly_ntt -* -* Description: Computes negacyclic number-theoretic transform (NTT) of -* a polynomial in place. -* Input coefficients assumed to be in normal order, -* output coefficients are in special order that is natural -* for the vectorization. Input coefficients are assumed to be -* bounded by q in absolute value, output coefficients are bounded -* by 16118 in absolute value. -* -* Arguments: - poly *r: pointer to in/output polynomial -**************************************************/ -void poly_ntt(poly *r) -{ - ntt_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_invntt_tomont -* -* Description: Computes inverse of negacyclic number-theoretic transform (NTT) -* of a polynomial in place; -* Input coefficients assumed to be in special order from vectorized -* forward ntt, output in normal order. Input coefficients can be -* arbitrary 16-bit integers, output coefficients are bounded by 14870 -* in absolute value. -* -* Arguments: - poly *a: pointer to in/output polynomial -**************************************************/ -void poly_invntt_tomont(poly *r) -{ - invntt_avx(r->vec, qdata.vec); -} - -void poly_nttunpack(poly *r) -{ - nttunpack_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_basemul_montgomery -* -* Description: Multiplication of two polynomials in NTT domain. 
-* One of the input polynomials needs to have coefficients -* bounded by q, the other polynomial can have arbitrary -* coefficients. Output coefficients are bounded by 6656. -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) -{ - basemul_avx(r->vec, a->vec, b->vec, qdata.vec); -} - -/************************************************* -* Name: poly_tomont -* -* Description: Inplace conversion of all coefficients of a polynomial -* from normal domain to Montgomery domain -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_tomont(poly *r) -{ - tomont_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_reduce -* -* Description: Applies Barrett reduction to all coefficients of a polynomial -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_reduce(poly *r) -{ - reduce_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_add -* -* Description: Add two polynomials. No modular reduction -* is performed. 
-* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_add(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - __m256i f0, f1; - - for(i=0;ivec[i]); - f1 = _mm256_load_si256(&b->vec[i]); - f0 = _mm256_add_epi16(f0, f1); - _mm256_store_si256(&r->vec[i], f0); - } -} - -/************************************************* -* Name: poly_sub -* -* Description: Subtract two polynomials. No modular reduction -* is performed. -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_sub(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - __m256i f0, f1; - - for(i=0;ivec[i]); - f1 = _mm256_load_si256(&b->vec[i]); - f0 = _mm256_sub_epi16(f0, f1); - _mm256_store_si256(&r->vec[i], f0); - } -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.h deleted file mode 100644 index 6a9cf71c70..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/poly.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef POLY_H -#define POLY_H - -#include -#include "align.h" -#include "params.h" - -typedef ALIGNED_INT16(KYBER_N) poly; - -#define poly_compress KYBER_NAMESPACE(poly_compress) -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a); -#define poly_decompress KYBER_NAMESPACE(poly_decompress) -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]); - -#define poly_tobytes KYBER_NAMESPACE(poly_tobytes) -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a); -#define poly_frombytes KYBER_NAMESPACE(poly_frombytes) -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]); - -#define 
poly_frommsg KYBER_NAMESPACE(poly_frommsg) -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]); -#define poly_tomsg KYBER_NAMESPACE(poly_tomsg) -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r); - -#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1) -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2) -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#ifndef KYBER_90S -#define poly_getnoise_eta1_4x KYBER_NAMESPACE(poly_getnoise_eta2_4x) -void poly_getnoise_eta1_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3); - -#if KYBER_K == 2 -#define poly_getnoise_eta1122_4x KYBER_NAMESPACE(poly_getnoise_eta1122_4x) -void poly_getnoise_eta1122_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3); -#endif -#endif - - -#define poly_ntt KYBER_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); -#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont) -void poly_invntt_tomont(poly *r); -#define poly_nttunpack KYBER_NAMESPACE(poly_nttunpack) -void poly_nttunpack(poly *r); -#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery) -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b); -#define poly_tomont KYBER_NAMESPACE(poly_tomont) -void poly_tomont(poly *r); - -#define poly_reduce KYBER_NAMESPACE(poly_reduce) -void poly_reduce(poly *r); - -#define poly_add KYBER_NAMESPACE(poly_add) -void poly_add(poly *r, const poly *a, const poly *b); -#define poly_sub KYBER_NAMESPACE(poly_sub) -void poly_sub(poly *r, const poly *a, const poly *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.c 
b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.c deleted file mode 100644 index a0174b7b3f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.c +++ /dev/null @@ -1,307 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" -#include "consts.h" - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) -static void poly_compress10(uint8_t r[320], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i v8 = _mm256_slli_epi16(v,3); - const __m256i off = _mm256_set1_epi16(15); - const __m256i shift1 = _mm256_set1_epi16(1 << 12); - const __m256i mask = _mm256_set1_epi16(1023); - const __m256i shift2 = _mm256_set1_epi64x((1024LL << 48) + (1LL << 32) + (1024 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(12); - const __m256i shufbidx = _mm256_set_epi8( 8, 4, 3, 2, 1, 0,-1,-1,-1,-1,-1,-1,12,11,10, 9, - -1,-1,-1,-1,-1,-1,12,11,10, 9, 8, 4, 3, 2, 1, 0); - - for(i=0;ivec[i]); - f1 = _mm256_mullo_epi16(f0,v8); - f2 = _mm256_add_epi16(f0,off); - f0 = _mm256_slli_epi16(f0,3); - f0 = _mm256_mulhi_epi16(f0,v); - f2 = _mm256_sub_epi16(f1,f2); - f1 = _mm256_andnot_si256(f1,f2); - f1 = _mm256_srli_epi16(f1,15); - f0 = _mm256_sub_epi16(f0,f1); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f0 = _mm256_and_si256(f0,mask); - f0 = _mm256_madd_epi16(f0,shift2); - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f0 = _mm256_srli_epi64(f0,12); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blend_epi16(t0,t1,0xE0); - _mm_storeu_si128((__m128i *)&r[20*i+ 0],t0); - memcpy(&r[20*i+16],&t1,4); - } -} - -static void poly_decompress10(poly * restrict r, const uint8_t a[320+12]) -{ - unsigned int i; - __m256i f; - const __m256i q = _mm256_set1_epi32((KYBER_Q << 16) + 4*KYBER_Q); - const __m256i shufbidx = 
_mm256_set_epi8(11,10,10, 9, 9, 8, 8, 7, - 6, 5, 5, 4, 4, 3, 3, 2, - 9, 8, 8, 7, 7, 6, 6, 5, - 4, 3, 3, 2, 2, 1, 1, 0); - const __m256i sllvdidx = _mm256_set1_epi64x(4); - const __m256i mask = _mm256_set1_epi32((32736 << 16) + 8184); - - for(i=0;ivec[i],f); - } -} - -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) -static void poly_compress11(uint8_t r[352+2], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i v8 = _mm256_slli_epi16(v,3); - const __m256i off = _mm256_set1_epi16(36); - const __m256i shift1 = _mm256_set1_epi16(1 << 13); - const __m256i mask = _mm256_set1_epi16(2047); - const __m256i shift2 = _mm256_set1_epi64x((2048LL << 48) + (1LL << 32) + (2048 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(10); - const __m256i srlvqidx = _mm256_set_epi64x(30,10,30,10); - const __m256i shufbidx = _mm256_set_epi8( 4, 3, 2, 1, 0, 0,-1,-1,-1,-1,10, 9, 8, 7, 6, 5, - -1,-1,-1,-1,-1,10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); - - for(i=0;ivec[i]); - f1 = _mm256_mullo_epi16(f0,v8); - f2 = _mm256_add_epi16(f0,off); - f0 = _mm256_slli_epi16(f0,3); - f0 = _mm256_mulhi_epi16(f0,v); - f2 = _mm256_sub_epi16(f1,f2); - f1 = _mm256_andnot_si256(f1,f2); - f1 = _mm256_srli_epi16(f1,15); - f0 = _mm256_sub_epi16(f0,f1); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f0 = _mm256_and_si256(f0,mask); - f0 = _mm256_madd_epi16(f0,shift2); - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f1 = _mm256_bsrli_epi128(f0,8); - f0 = _mm256_srlv_epi64(f0,srlvqidx); - f1 = _mm256_slli_epi64(f1,34); - f0 = _mm256_add_epi64(f0,f1); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blendv_epi8(t0,t1,_mm256_castsi256_si128(shufbidx)); - _mm_storeu_si128((__m128i *)&r[22*i+ 0],t0); - _mm_storel_epi64((__m128i *)&r[22*i+16],t1); - } -} - -static void poly_decompress11(poly * restrict r, const uint8_t a[352+10]) -{ - 
unsigned int i; - __m256i f; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(13,12,12,11,10, 9, 9, 8, - 8, 7, 6, 5, 5, 4, 4, 3, - 10, 9, 9, 8, 7, 6, 6, 5, - 5, 4, 3, 2, 2, 1, 1, 0); - const __m256i srlvdidx = _mm256_set_epi32(0,0,1,0,0,0,1,0); - const __m256i srlvqidx = _mm256_set_epi64x(2,0,2,0); - const __m256i shift = _mm256_set_epi16(4,32,1,8,32,1,4,32,4,32,1,8,32,1,4,32); - const __m256i mask = _mm256_set1_epi16(32752); - - for(i=0;ivec[i],f); - } -} - -#endif - -/************************************************* -* Name: polyvec_compress -* -* Description: Compress and serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECCOMPRESSEDBYTES) -* - polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES+2], const polyvec *a) -{ - unsigned int i; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - for(i=0;ivec[i]); -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - for(i=0;ivec[i]); -#endif -} - -/************************************************* -* Name: polyvec_decompress -* -* Description: De-serialize and decompress vector of polynomials; -* approximate inverse of polyvec_compress -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYVECCOMPRESSEDBYTES) -**************************************************/ -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES+12]) -{ - unsigned int i; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - for(i=0;ivec[i],&a[320*i]); -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - for(i=0;ivec[i],&a[352*i]); -#endif -} - -/************************************************* -* Name: polyvec_tobytes -* -* Description: Serialize vector of polynomials -* 
-* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECBYTES) -* - polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_frombytes -* -* Description: De-serialize vector of polynomials; -* inverse of polyvec_tobytes -* -* Arguments: - uint8_t *r: pointer to output byte array -* - const polyvec *a: pointer to input vector of polynomials -* (of length KYBER_POLYVECBYTES) -**************************************************/ -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]) -{ - unsigned int i; - for(i=0;ivec[i], a+i*KYBER_POLYBYTES); -} - -/************************************************* -* Name: polyvec_ntt -* -* Description: Apply forward NTT to all elements of a vector of polynomials -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_ntt(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_invntt_tomont -* -* Description: Apply inverse NTT to all elements of a vector of polynomials -* and multiply by Montgomery factor 2^16 -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_invntt_tomont(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_basemul_acc_montgomery -* -* Description: Multiply elements in a and b in NTT domain, accumulate into r, -* and multiply by 2^-16. 
-* -* Arguments: - poly *r: pointer to output polynomial -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - poly tmp; - - poly_basemul_montgomery(r,&a->vec[0],&b->vec[0]); - for(i=1;ivec[i],&b->vec[i]); - poly_add(r,r,&tmp); - } -} - -/************************************************* -* Name: polyvec_reduce -* -* Description: Applies Barrett reduction to each coefficient -* of each element of a vector of polynomials; -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - polyvec *r: pointer to input/output polynomial -**************************************************/ -void polyvec_reduce(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_add -* -* Description: Add vectors of polynomials -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - for(i=0;ivec[i], &a->vec[i], &b->vec[i]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.h deleted file mode 100644 index 2ce23c31ff..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/polyvec.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef POLYVEC_H -#define POLYVEC_H - -#include -#include "params.h" -#include "poly.h" - -typedef struct{ - poly vec[KYBER_K]; -} polyvec; - -#define polyvec_compress KYBER_NAMESPACE(polyvec_compress) -void polyvec_compress(uint8_t 
r[KYBER_POLYVECCOMPRESSEDBYTES+2], const polyvec *a); -#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress) -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES+12]); - -#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes) -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a); -#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes) -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]); - -#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt) -void polyvec_ntt(polyvec *r); -#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont) -void polyvec_invntt_tomont(polyvec *r); - -#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery) -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b); - -#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce) -void polyvec_reduce(polyvec *r); - -#define polyvec_add KYBER_NAMESPACE(polyvec_add) -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/reduce.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/reduce.h deleted file mode 100644 index 5368185b5f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/reduce.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef REDUCE_H -#define REDUCE_H - -#include "params.h" -#include - -#define reduce_avx KYBER_NAMESPACE(reduce_avx) -void reduce_avx(__m256i *r, const __m256i *qdata); -#define tomont_avx KYBER_NAMESPACE(tomont_avx) -void tomont_avx(__m256i *r, const __m256i *qdata); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.c deleted file mode 100644 index 9060a44cb9..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.c +++ /dev/null @@ -1,398 +0,0 @@ -#include -#include -#include 
-#include "params.h" -#include "consts.h" -#include "rejsample.h" - -//#define BMI - -#ifndef BMI -static const uint8_t idx[256][8] = { - {-1, -1, -1, -1, -1, -1, -1, -1}, - { 0, -1, -1, -1, -1, -1, -1, -1}, - { 2, -1, -1, -1, -1, -1, -1, -1}, - { 0, 2, -1, -1, -1, -1, -1, -1}, - { 4, -1, -1, -1, -1, -1, -1, -1}, - { 0, 4, -1, -1, -1, -1, -1, -1}, - { 2, 4, -1, -1, -1, -1, -1, -1}, - { 0, 2, 4, -1, -1, -1, -1, -1}, - { 6, -1, -1, -1, -1, -1, -1, -1}, - { 0, 6, -1, -1, -1, -1, -1, -1}, - { 2, 6, -1, -1, -1, -1, -1, -1}, - { 0, 2, 6, -1, -1, -1, -1, -1}, - { 4, 6, -1, -1, -1, -1, -1, -1}, - { 0, 4, 6, -1, -1, -1, -1, -1}, - { 2, 4, 6, -1, -1, -1, -1, -1}, - { 0, 2, 4, 6, -1, -1, -1, -1}, - { 8, -1, -1, -1, -1, -1, -1, -1}, - { 0, 8, -1, -1, -1, -1, -1, -1}, - { 2, 8, -1, -1, -1, -1, -1, -1}, - { 0, 2, 8, -1, -1, -1, -1, -1}, - { 4, 8, -1, -1, -1, -1, -1, -1}, - { 0, 4, 8, -1, -1, -1, -1, -1}, - { 2, 4, 8, -1, -1, -1, -1, -1}, - { 0, 2, 4, 8, -1, -1, -1, -1}, - { 6, 8, -1, -1, -1, -1, -1, -1}, - { 0, 6, 8, -1, -1, -1, -1, -1}, - { 2, 6, 8, -1, -1, -1, -1, -1}, - { 0, 2, 6, 8, -1, -1, -1, -1}, - { 4, 6, 8, -1, -1, -1, -1, -1}, - { 0, 4, 6, 8, -1, -1, -1, -1}, - { 2, 4, 6, 8, -1, -1, -1, -1}, - { 0, 2, 4, 6, 8, -1, -1, -1}, - {10, -1, -1, -1, -1, -1, -1, -1}, - { 0, 10, -1, -1, -1, -1, -1, -1}, - { 2, 10, -1, -1, -1, -1, -1, -1}, - { 0, 2, 10, -1, -1, -1, -1, -1}, - { 4, 10, -1, -1, -1, -1, -1, -1}, - { 0, 4, 10, -1, -1, -1, -1, -1}, - { 2, 4, 10, -1, -1, -1, -1, -1}, - { 0, 2, 4, 10, -1, -1, -1, -1}, - { 6, 10, -1, -1, -1, -1, -1, -1}, - { 0, 6, 10, -1, -1, -1, -1, -1}, - { 2, 6, 10, -1, -1, -1, -1, -1}, - { 0, 2, 6, 10, -1, -1, -1, -1}, - { 4, 6, 10, -1, -1, -1, -1, -1}, - { 0, 4, 6, 10, -1, -1, -1, -1}, - { 2, 4, 6, 10, -1, -1, -1, -1}, - { 0, 2, 4, 6, 10, -1, -1, -1}, - { 8, 10, -1, -1, -1, -1, -1, -1}, - { 0, 8, 10, -1, -1, -1, -1, -1}, - { 2, 8, 10, -1, -1, -1, -1, -1}, - { 0, 2, 8, 10, -1, -1, -1, -1}, - { 4, 8, 10, -1, -1, -1, -1, -1}, - { 0, 4, 8, 10, -1, -1, 
-1, -1}, - { 2, 4, 8, 10, -1, -1, -1, -1}, - { 0, 2, 4, 8, 10, -1, -1, -1}, - { 6, 8, 10, -1, -1, -1, -1, -1}, - { 0, 6, 8, 10, -1, -1, -1, -1}, - { 2, 6, 8, 10, -1, -1, -1, -1}, - { 0, 2, 6, 8, 10, -1, -1, -1}, - { 4, 6, 8, 10, -1, -1, -1, -1}, - { 0, 4, 6, 8, 10, -1, -1, -1}, - { 2, 4, 6, 8, 10, -1, -1, -1}, - { 0, 2, 4, 6, 8, 10, -1, -1}, - {12, -1, -1, -1, -1, -1, -1, -1}, - { 0, 12, -1, -1, -1, -1, -1, -1}, - { 2, 12, -1, -1, -1, -1, -1, -1}, - { 0, 2, 12, -1, -1, -1, -1, -1}, - { 4, 12, -1, -1, -1, -1, -1, -1}, - { 0, 4, 12, -1, -1, -1, -1, -1}, - { 2, 4, 12, -1, -1, -1, -1, -1}, - { 0, 2, 4, 12, -1, -1, -1, -1}, - { 6, 12, -1, -1, -1, -1, -1, -1}, - { 0, 6, 12, -1, -1, -1, -1, -1}, - { 2, 6, 12, -1, -1, -1, -1, -1}, - { 0, 2, 6, 12, -1, -1, -1, -1}, - { 4, 6, 12, -1, -1, -1, -1, -1}, - { 0, 4, 6, 12, -1, -1, -1, -1}, - { 2, 4, 6, 12, -1, -1, -1, -1}, - { 0, 2, 4, 6, 12, -1, -1, -1}, - { 8, 12, -1, -1, -1, -1, -1, -1}, - { 0, 8, 12, -1, -1, -1, -1, -1}, - { 2, 8, 12, -1, -1, -1, -1, -1}, - { 0, 2, 8, 12, -1, -1, -1, -1}, - { 4, 8, 12, -1, -1, -1, -1, -1}, - { 0, 4, 8, 12, -1, -1, -1, -1}, - { 2, 4, 8, 12, -1, -1, -1, -1}, - { 0, 2, 4, 8, 12, -1, -1, -1}, - { 6, 8, 12, -1, -1, -1, -1, -1}, - { 0, 6, 8, 12, -1, -1, -1, -1}, - { 2, 6, 8, 12, -1, -1, -1, -1}, - { 0, 2, 6, 8, 12, -1, -1, -1}, - { 4, 6, 8, 12, -1, -1, -1, -1}, - { 0, 4, 6, 8, 12, -1, -1, -1}, - { 2, 4, 6, 8, 12, -1, -1, -1}, - { 0, 2, 4, 6, 8, 12, -1, -1}, - {10, 12, -1, -1, -1, -1, -1, -1}, - { 0, 10, 12, -1, -1, -1, -1, -1}, - { 2, 10, 12, -1, -1, -1, -1, -1}, - { 0, 2, 10, 12, -1, -1, -1, -1}, - { 4, 10, 12, -1, -1, -1, -1, -1}, - { 0, 4, 10, 12, -1, -1, -1, -1}, - { 2, 4, 10, 12, -1, -1, -1, -1}, - { 0, 2, 4, 10, 12, -1, -1, -1}, - { 6, 10, 12, -1, -1, -1, -1, -1}, - { 0, 6, 10, 12, -1, -1, -1, -1}, - { 2, 6, 10, 12, -1, -1, -1, -1}, - { 0, 2, 6, 10, 12, -1, -1, -1}, - { 4, 6, 10, 12, -1, -1, -1, -1}, - { 0, 4, 6, 10, 12, -1, -1, -1}, - { 2, 4, 6, 10, 12, -1, -1, -1}, - { 0, 2, 4, 6, 10, 12, 
-1, -1}, - { 8, 10, 12, -1, -1, -1, -1, -1}, - { 0, 8, 10, 12, -1, -1, -1, -1}, - { 2, 8, 10, 12, -1, -1, -1, -1}, - { 0, 2, 8, 10, 12, -1, -1, -1}, - { 4, 8, 10, 12, -1, -1, -1, -1}, - { 0, 4, 8, 10, 12, -1, -1, -1}, - { 2, 4, 8, 10, 12, -1, -1, -1}, - { 0, 2, 4, 8, 10, 12, -1, -1}, - { 6, 8, 10, 12, -1, -1, -1, -1}, - { 0, 6, 8, 10, 12, -1, -1, -1}, - { 2, 6, 8, 10, 12, -1, -1, -1}, - { 0, 2, 6, 8, 10, 12, -1, -1}, - { 4, 6, 8, 10, 12, -1, -1, -1}, - { 0, 4, 6, 8, 10, 12, -1, -1}, - { 2, 4, 6, 8, 10, 12, -1, -1}, - { 0, 2, 4, 6, 8, 10, 12, -1}, - {14, -1, -1, -1, -1, -1, -1, -1}, - { 0, 14, -1, -1, -1, -1, -1, -1}, - { 2, 14, -1, -1, -1, -1, -1, -1}, - { 0, 2, 14, -1, -1, -1, -1, -1}, - { 4, 14, -1, -1, -1, -1, -1, -1}, - { 0, 4, 14, -1, -1, -1, -1, -1}, - { 2, 4, 14, -1, -1, -1, -1, -1}, - { 0, 2, 4, 14, -1, -1, -1, -1}, - { 6, 14, -1, -1, -1, -1, -1, -1}, - { 0, 6, 14, -1, -1, -1, -1, -1}, - { 2, 6, 14, -1, -1, -1, -1, -1}, - { 0, 2, 6, 14, -1, -1, -1, -1}, - { 4, 6, 14, -1, -1, -1, -1, -1}, - { 0, 4, 6, 14, -1, -1, -1, -1}, - { 2, 4, 6, 14, -1, -1, -1, -1}, - { 0, 2, 4, 6, 14, -1, -1, -1}, - { 8, 14, -1, -1, -1, -1, -1, -1}, - { 0, 8, 14, -1, -1, -1, -1, -1}, - { 2, 8, 14, -1, -1, -1, -1, -1}, - { 0, 2, 8, 14, -1, -1, -1, -1}, - { 4, 8, 14, -1, -1, -1, -1, -1}, - { 0, 4, 8, 14, -1, -1, -1, -1}, - { 2, 4, 8, 14, -1, -1, -1, -1}, - { 0, 2, 4, 8, 14, -1, -1, -1}, - { 6, 8, 14, -1, -1, -1, -1, -1}, - { 0, 6, 8, 14, -1, -1, -1, -1}, - { 2, 6, 8, 14, -1, -1, -1, -1}, - { 0, 2, 6, 8, 14, -1, -1, -1}, - { 4, 6, 8, 14, -1, -1, -1, -1}, - { 0, 4, 6, 8, 14, -1, -1, -1}, - { 2, 4, 6, 8, 14, -1, -1, -1}, - { 0, 2, 4, 6, 8, 14, -1, -1}, - {10, 14, -1, -1, -1, -1, -1, -1}, - { 0, 10, 14, -1, -1, -1, -1, -1}, - { 2, 10, 14, -1, -1, -1, -1, -1}, - { 0, 2, 10, 14, -1, -1, -1, -1}, - { 4, 10, 14, -1, -1, -1, -1, -1}, - { 0, 4, 10, 14, -1, -1, -1, -1}, - { 2, 4, 10, 14, -1, -1, -1, -1}, - { 0, 2, 4, 10, 14, -1, -1, -1}, - { 6, 10, 14, -1, -1, -1, -1, -1}, - { 0, 6, 10, 14, -1, 
-1, -1, -1}, - { 2, 6, 10, 14, -1, -1, -1, -1}, - { 0, 2, 6, 10, 14, -1, -1, -1}, - { 4, 6, 10, 14, -1, -1, -1, -1}, - { 0, 4, 6, 10, 14, -1, -1, -1}, - { 2, 4, 6, 10, 14, -1, -1, -1}, - { 0, 2, 4, 6, 10, 14, -1, -1}, - { 8, 10, 14, -1, -1, -1, -1, -1}, - { 0, 8, 10, 14, -1, -1, -1, -1}, - { 2, 8, 10, 14, -1, -1, -1, -1}, - { 0, 2, 8, 10, 14, -1, -1, -1}, - { 4, 8, 10, 14, -1, -1, -1, -1}, - { 0, 4, 8, 10, 14, -1, -1, -1}, - { 2, 4, 8, 10, 14, -1, -1, -1}, - { 0, 2, 4, 8, 10, 14, -1, -1}, - { 6, 8, 10, 14, -1, -1, -1, -1}, - { 0, 6, 8, 10, 14, -1, -1, -1}, - { 2, 6, 8, 10, 14, -1, -1, -1}, - { 0, 2, 6, 8, 10, 14, -1, -1}, - { 4, 6, 8, 10, 14, -1, -1, -1}, - { 0, 4, 6, 8, 10, 14, -1, -1}, - { 2, 4, 6, 8, 10, 14, -1, -1}, - { 0, 2, 4, 6, 8, 10, 14, -1}, - {12, 14, -1, -1, -1, -1, -1, -1}, - { 0, 12, 14, -1, -1, -1, -1, -1}, - { 2, 12, 14, -1, -1, -1, -1, -1}, - { 0, 2, 12, 14, -1, -1, -1, -1}, - { 4, 12, 14, -1, -1, -1, -1, -1}, - { 0, 4, 12, 14, -1, -1, -1, -1}, - { 2, 4, 12, 14, -1, -1, -1, -1}, - { 0, 2, 4, 12, 14, -1, -1, -1}, - { 6, 12, 14, -1, -1, -1, -1, -1}, - { 0, 6, 12, 14, -1, -1, -1, -1}, - { 2, 6, 12, 14, -1, -1, -1, -1}, - { 0, 2, 6, 12, 14, -1, -1, -1}, - { 4, 6, 12, 14, -1, -1, -1, -1}, - { 0, 4, 6, 12, 14, -1, -1, -1}, - { 2, 4, 6, 12, 14, -1, -1, -1}, - { 0, 2, 4, 6, 12, 14, -1, -1}, - { 8, 12, 14, -1, -1, -1, -1, -1}, - { 0, 8, 12, 14, -1, -1, -1, -1}, - { 2, 8, 12, 14, -1, -1, -1, -1}, - { 0, 2, 8, 12, 14, -1, -1, -1}, - { 4, 8, 12, 14, -1, -1, -1, -1}, - { 0, 4, 8, 12, 14, -1, -1, -1}, - { 2, 4, 8, 12, 14, -1, -1, -1}, - { 0, 2, 4, 8, 12, 14, -1, -1}, - { 6, 8, 12, 14, -1, -1, -1, -1}, - { 0, 6, 8, 12, 14, -1, -1, -1}, - { 2, 6, 8, 12, 14, -1, -1, -1}, - { 0, 2, 6, 8, 12, 14, -1, -1}, - { 4, 6, 8, 12, 14, -1, -1, -1}, - { 0, 4, 6, 8, 12, 14, -1, -1}, - { 2, 4, 6, 8, 12, 14, -1, -1}, - { 0, 2, 4, 6, 8, 12, 14, -1}, - {10, 12, 14, -1, -1, -1, -1, -1}, - { 0, 10, 12, 14, -1, -1, -1, -1}, - { 2, 10, 12, 14, -1, -1, -1, -1}, - { 0, 2, 10, 12, 14, -1, 
-1, -1}, - { 4, 10, 12, 14, -1, -1, -1, -1}, - { 0, 4, 10, 12, 14, -1, -1, -1}, - { 2, 4, 10, 12, 14, -1, -1, -1}, - { 0, 2, 4, 10, 12, 14, -1, -1}, - { 6, 10, 12, 14, -1, -1, -1, -1}, - { 0, 6, 10, 12, 14, -1, -1, -1}, - { 2, 6, 10, 12, 14, -1, -1, -1}, - { 0, 2, 6, 10, 12, 14, -1, -1}, - { 4, 6, 10, 12, 14, -1, -1, -1}, - { 0, 4, 6, 10, 12, 14, -1, -1}, - { 2, 4, 6, 10, 12, 14, -1, -1}, - { 0, 2, 4, 6, 10, 12, 14, -1}, - { 8, 10, 12, 14, -1, -1, -1, -1}, - { 0, 8, 10, 12, 14, -1, -1, -1}, - { 2, 8, 10, 12, 14, -1, -1, -1}, - { 0, 2, 8, 10, 12, 14, -1, -1}, - { 4, 8, 10, 12, 14, -1, -1, -1}, - { 0, 4, 8, 10, 12, 14, -1, -1}, - { 2, 4, 8, 10, 12, 14, -1, -1}, - { 0, 2, 4, 8, 10, 12, 14, -1}, - { 6, 8, 10, 12, 14, -1, -1, -1}, - { 0, 6, 8, 10, 12, 14, -1, -1}, - { 2, 6, 8, 10, 12, 14, -1, -1}, - { 0, 2, 6, 8, 10, 12, 14, -1}, - { 4, 6, 8, 10, 12, 14, -1, -1}, - { 0, 4, 6, 8, 10, 12, 14, -1}, - { 2, 4, 6, 8, 10, 12, 14, -1}, - { 0, 2, 4, 6, 8, 10, 12, 14} -}; -#endif - -#define _mm256_cmpge_epu16(a, b) _mm256_cmpeq_epi16(_mm256_max_epu16(a, b), a) -#define _mm_cmpge_epu16(a, b) _mm_cmpeq_epi16(_mm_max_epu16(a, b), a) - -unsigned int rej_uniform_avx(int16_t * restrict r, const uint8_t *buf) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - uint32_t good; -#ifdef BMI - uint64_t idx0, idx1, idx2, idx3; -#endif - const __m256i bound = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i ones = _mm256_set1_epi8(1); - const __m256i mask = _mm256_set1_epi16(0xFFF); - const __m256i idx8 = _mm256_set_epi8(15,14,14,13,12,11,11,10, - 9, 8, 8, 7, 6, 5, 5, 4, - 11,10,10, 9, 8, 7, 7, 6, - 5, 4, 4, 3, 2, 1, 1, 0); - __m256i f0, f1, g0, g1, g2, g3; - __m128i f, t, pilo, pihi; - - ctr = pos = 0; - while(ctr <= KYBER_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 56) { - f0 = _mm256_loadu_si256((__m256i *)&buf[pos]); - f1 = _mm256_loadu_si256((__m256i *)&buf[pos+24]); - f0 = _mm256_permute4x64_epi64(f0, 0x94); - f1 = _mm256_permute4x64_epi64(f1, 0x94); - f0 = _mm256_shuffle_epi8(f0, 
idx8); - f1 = _mm256_shuffle_epi8(f1, idx8); - g0 = _mm256_srli_epi16(f0, 4); - g1 = _mm256_srli_epi16(f1, 4); - f0 = _mm256_blend_epi16(f0, g0, 0xAA); - f1 = _mm256_blend_epi16(f1, g1, 0xAA); - f0 = _mm256_and_si256(f0, mask); - f1 = _mm256_and_si256(f1, mask); - pos += 48; - - g0 = _mm256_cmpgt_epi16(bound, f0); - g1 = _mm256_cmpgt_epi16(bound, f1); - - g0 = _mm256_packs_epi16(g0, g1); - good = _mm256_movemask_epi8(g0); - -#ifdef BMI - idx0 = _pdep_u64(good >> 0, 0x0101010101010101); - idx1 = _pdep_u64(good >> 8, 0x0101010101010101); - idx2 = _pdep_u64(good >> 16, 0x0101010101010101); - idx3 = _pdep_u64(good >> 24, 0x0101010101010101); - idx0 = (idx0 << 8) - idx0; - idx0 = _pext_u64(0x0E0C0A0806040200, idx0); - idx1 = (idx1 << 8) - idx1; - idx1 = _pext_u64(0x0E0C0A0806040200, idx1); - idx2 = (idx2 << 8) - idx2; - idx2 = _pext_u64(0x0E0C0A0806040200, idx2); - idx3 = (idx3 << 8) - idx3; - idx3 = _pext_u64(0x0E0C0A0806040200, idx3); - - g0 = _mm256_castsi128_si256(_mm_cvtsi64_si128(idx0)); - g1 = _mm256_castsi128_si256(_mm_cvtsi64_si128(idx1)); - g0 = _mm256_inserti128_si256(g0, _mm_cvtsi64_si128(idx2), 1); - g1 = _mm256_inserti128_si256(g1, _mm_cvtsi64_si128(idx3), 1); -#else - g0 = _mm256_castsi128_si256(_mm_loadl_epi64((__m128i *)&idx[(good >> 0) & 0xFF])); - g1 = _mm256_castsi128_si256(_mm_loadl_epi64((__m128i *)&idx[(good >> 8) & 0xFF])); - g0 = _mm256_inserti128_si256(g0, _mm_loadl_epi64((__m128i *)&idx[(good >> 16) & 0xFF]), 1); - g1 = _mm256_inserti128_si256(g1, _mm_loadl_epi64((__m128i *)&idx[(good >> 24) & 0xFF]), 1); -#endif - - g2 = _mm256_add_epi8(g0, ones); - g3 = _mm256_add_epi8(g1, ones); - g0 = _mm256_unpacklo_epi8(g0, g2); - g1 = _mm256_unpacklo_epi8(g1, g3); - - f0 = _mm256_shuffle_epi8(f0, g0); - f1 = _mm256_shuffle_epi8(f1, g1); - - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f0)); - ctr += _mm_popcnt_u32((good >> 0) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f0, 1)); - ctr += _mm_popcnt_u32((good >> 
16) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f1)); - ctr += _mm_popcnt_u32((good >> 8) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f1, 1)); - ctr += _mm_popcnt_u32((good >> 24) & 0xFF); - } - - while(ctr <= KYBER_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 16) { - f = _mm_loadu_si128((__m128i *)&buf[pos]); - f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8)); - t = _mm_srli_epi16(f, 4); - f = _mm_blend_epi16(f, t, 0xAA); - f = _mm_and_si128(f, _mm256_castsi256_si128(mask)); - pos += 12; - - t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f); - good = _mm_movemask_epi8(t); - -#ifdef BMI - good &= 0x5555; - idx0 = _pdep_u64(good, 0x1111111111111111); - idx0 = (idx0 << 8) - idx0; - idx0 = _pext_u64(0x0E0C0A0806040200, idx0); - pilo = _mm_cvtsi64_si128(idx0); -#else - good = _pext_u32(good, 0x5555); - pilo = _mm_loadl_epi64((__m128i *)&idx[good]); -#endif - - pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones)); - pilo = _mm_unpacklo_epi8(pilo, pihi); - f = _mm_shuffle_epi8(f, pilo); - _mm_storeu_si128((__m128i *)&r[ctr], f); - ctr += _mm_popcnt_u32(good); - } - - while(ctr < KYBER_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3) { - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)); - pos += 3; - - if(val0 < KYBER_Q) - r[ctr++] = val0; - if(val1 < KYBER_Q && ctr < KYBER_N) - r[ctr++] = val1; - } - - return ctr; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.h deleted file mode 100644 index 3be5e2192e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/rejsample.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef REJSAMPLE_H -#define REJSAMPLE_H - -#include -#include "params.h" -#include "symmetric.h" - -#define REJ_UNIFORM_AVX_NBLOCKS ((12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) -#define 
REJ_UNIFORM_AVX_BUFLEN (REJ_UNIFORM_AVX_NBLOCKS*XOF_BLOCKBYTES) - -#define rej_uniform_avx KYBER_NAMESPACE(rej_uniform_avx) -unsigned int rej_uniform_avx(int16_t *r, const uint8_t *buf); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric-shake.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric-shake.c deleted file mode 100644 index 20f451882e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric-shake.c +++ /dev/null @@ -1,74 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "symmetric.h" -#include "fips202.h" - -/************************************************* -* Name: kyber_shake128_absorb -* -* Description: Absorb step of the SHAKE128 specialized for the Kyber context. -* -* Arguments: - keccak_state *state: pointer to (uninitialized) output Keccak state -* - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state -* - uint8_t i: additional byte of input -* - uint8_t j: additional byte of input -**************************************************/ -void kyber_shake128_absorb(shake128incctx *state, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y) -{ - uint8_t extseed[KYBER_SYMBYTES+2]; - - memcpy(extseed, seed, KYBER_SYMBYTES); - extseed[KYBER_SYMBYTES+0] = x; - extseed[KYBER_SYMBYTES+1] = y; - - shake128_absorb_once(state, extseed, sizeof(extseed)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t 
key[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t extkey[KYBER_SYMBYTES+1]; - - memcpy(extkey, key, KYBER_SYMBYTES); - extkey[KYBER_SYMBYTES] = nonce; - - shake256(out, outlen, extkey, sizeof(extkey)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]) -{ - shake256incctx s; - - shake256_inc_init(&s); - shake256_inc_absorb(&s, key, KYBER_SYMBYTES); - shake256_inc_absorb(&s, input, KYBER_CIPHERTEXTBYTES); - shake256_inc_finalize(&s); - shake256_inc_squeeze(out, KYBER_SSBYTES, &s); - shake256_inc_ctx_release(&s); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric.h deleted file mode 100644 index e4941f7a86..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/symmetric.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef SYMMETRIC_H -#define SYMMETRIC_H - -#include -#include -#include "params.h" - -#include "fips202.h" -#include "fips202x4.h" - -typedef shake128incctx xof_state; - -#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb) -void kyber_shake128_absorb(shake128incctx *s, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y); - -#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf) -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce); - -#define kyber_shake256_rkprf 
KYBER_NAMESPACE(kyber_shake256_rkprf) -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]); - -#define XOF_BLOCKBYTES SHAKE128_RATE - -#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) -#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) -#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y) -#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) -#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE) -#define rkprf(OUT, KEY, INPUT) kyber_shake256_rkprf(OUT, KEY, INPUT) - -#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/verify.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/verify.c deleted file mode 100644 index 06243b837f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_avx2/verify.c +++ /dev/null @@ -1,83 +0,0 @@ -#include -#include -#include -#include "verify.h" - -/************************************************* -* Name: verify -* -* Description: Compare two arrays for equality in constant time. -* -* Arguments: const uint8_t *a: pointer to first byte array -* const uint8_t *b: pointer to second byte array -* size_t len: length of the byte arrays -* -* Returns 0 if the byte arrays are equal, 1 otherwise -**************************************************/ -int verify(const uint8_t *a, const uint8_t *b, size_t len) -{ - size_t i; - uint64_t r; - __m256i f, g, h; - - h = _mm256_setzero_si256(); - for(i=0;i> 63; - return r; -} - -/************************************************* -* Name: cmov -* -* Description: Copy len bytes from x to r if b is 1; -* don't modify x if b is 0. Requires b to be in {0,1}; -* assumes two's complement representation of negative integers. -* Runs in constant time. 
-* -* Arguments: uint8_t *r: pointer to output byte array -* const uint8_t *x: pointer to input byte array -* size_t len: Amount of bytes to be copied -* uint8_t b: Condition bit; has to be in {0,1} -**************************************************/ -void cmov(uint8_t * restrict r, const uint8_t *x, size_t len, uint8_t b) -{ - size_t i; - __m256i xvec, rvec, bvec; - -#if defined(__GNUC__) || defined(__clang__) - // Prevent the compiler from - // 1) inferring that b is 0/1-valued, and - // 2) handling the two cases with a branch. - // This is not necessary when verify.c and kem.c are separate translation - // units, but we expect that downstream consumers will copy this code and/or - // change how it is built. - __asm__("" : "+r"(b) : /* no inputs */); -#endif - - bvec = _mm256_set1_epi64x(-(uint64_t)b); - for(i=0;i -#include -#include "params.h" - -#define verify KYBER_NAMESPACE(verify) -int verify(const uint8_t *a, const uint8_t *b, size_t len); - -#define cmov KYBER_NAMESPACE(cmov) -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); - -#define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/api.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/api.h deleted file mode 100644 index 70d40f3f3e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/api.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef API_H -#define API_H - -#include - -#define pqcrystals_kyber512_SECRETKEYBYTES 1632 -#define pqcrystals_kyber512_PUBLICKEYBYTES 800 -#define pqcrystals_kyber512_CIPHERTEXTBYTES 768 -#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber512_ENCCOINBYTES 32 -#define pqcrystals_kyber512_BYTES 32 - -#define pqcrystals_kyber512_ref_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES -#define pqcrystals_kyber512_ref_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES -#define 
pqcrystals_kyber512_ref_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES -#define pqcrystals_kyber512_ref_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES -#define pqcrystals_kyber512_ref_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES -#define pqcrystals_kyber512_ref_BYTES pqcrystals_kyber512_BYTES - -int pqcrystals_kyber512_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber512_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber512_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber512_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber512_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber768_SECRETKEYBYTES 2400 -#define pqcrystals_kyber768_PUBLICKEYBYTES 1184 -#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088 -#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber768_ENCCOINBYTES 32 -#define pqcrystals_kyber768_BYTES 32 - -#define pqcrystals_kyber768_ref_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES -#define pqcrystals_kyber768_ref_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES -#define pqcrystals_kyber768_ref_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES -#define pqcrystals_kyber768_ref_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES -#define pqcrystals_kyber768_ref_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES -#define pqcrystals_kyber768_ref_BYTES pqcrystals_kyber768_BYTES - -int pqcrystals_kyber768_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber768_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber768_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber768_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber768_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber1024_SECRETKEYBYTES 3168 -#define 
pqcrystals_kyber1024_PUBLICKEYBYTES 1568 -#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568 -#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber1024_ENCCOINBYTES 32 -#define pqcrystals_kyber1024_BYTES 32 - -#define pqcrystals_kyber1024_ref_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES -#define pqcrystals_kyber1024_ref_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES -#define pqcrystals_kyber1024_ref_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES -#define pqcrystals_kyber1024_ref_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES -#define pqcrystals_kyber1024_ref_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES -#define pqcrystals_kyber1024_ref_BYTES pqcrystals_kyber1024_BYTES - -int pqcrystals_kyber1024_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber1024_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber1024_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber1024_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber1024_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.c deleted file mode 100644 index 1500ffea56..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.c +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include "params.h" -#include "cbd.h" - -/************************************************* -* Name: load32_littleendian -* -* Description: load 4 bytes into a 32-bit integer -* in little-endian order -* -* Arguments: - const uint8_t *x: pointer to input byte array -* -* Returns 32-bit unsigned integer loaded from x -**************************************************/ -static uint32_t load32_littleendian(const uint8_t x[4]) -{ - uint32_t r; - r = (uint32_t)x[0]; - r |= (uint32_t)x[1] << 8; - r |= (uint32_t)x[2] << 16; - 
r |= (uint32_t)x[3] << 24; - return r; -} - -/************************************************* -* Name: load24_littleendian -* -* Description: load 3 bytes into a 32-bit integer -* in little-endian order. -* This function is only needed for Kyber-512 -* -* Arguments: - const uint8_t *x: pointer to input byte array -* -* Returns 32-bit unsigned integer loaded from x (most significant byte is zero) -**************************************************/ -#if KYBER_ETA1 == 3 -static uint32_t load24_littleendian(const uint8_t x[3]) -{ - uint32_t r; - r = (uint32_t)x[0]; - r |= (uint32_t)x[1] << 8; - r |= (uint32_t)x[2] << 16; - return r; -} -#endif - - -/************************************************* -* Name: cbd2 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *buf: pointer to input byte array -**************************************************/ -static void cbd2(poly *r, const uint8_t buf[2*KYBER_N/4]) -{ - unsigned int i,j; - uint32_t t,d; - int16_t a,b; - - for(i=0;i>1) & 0x55555555; - - for(j=0;j<8;j++) { - a = (d >> (4*j+0)) & 0x3; - b = (d >> (4*j+2)) & 0x3; - r->coeffs[8*i+j] = a - b; - } - } -} - -/************************************************* -* Name: cbd3 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=3. 
-* This function is only needed for Kyber-512 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *buf: pointer to input byte array -**************************************************/ -#if KYBER_ETA1 == 3 -static void cbd3(poly *r, const uint8_t buf[3*KYBER_N/4]) -{ - unsigned int i,j; - uint32_t t,d; - int16_t a,b; - - for(i=0;i>1) & 0x00249249; - d += (t>>2) & 0x00249249; - - for(j=0;j<4;j++) { - a = (d >> (6*j+0)) & 0x7; - b = (d >> (6*j+3)) & 0x7; - r->coeffs[4*i+j] = a - b; - } - } -} -#endif - -void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1*KYBER_N/4]) -{ -#if KYBER_ETA1 == 2 - cbd2(r, buf); -#elif KYBER_ETA1 == 3 - cbd3(r, buf); -#else -#error "This implementation requires eta1 in {2,3}" -#endif -} - -void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2*KYBER_N/4]) -{ -#if KYBER_ETA2 == 2 - cbd2(r, buf); -#else -#error "This implementation requires eta2 = 2" -#endif -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.h deleted file mode 100644 index 7b677d745d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/cbd.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CBD_H -#define CBD_H - -#include -#include "params.h" -#include "poly.h" - -#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1) -void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1*KYBER_N/4]); - -#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2) -void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2*KYBER_N/4]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/indcpa.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/indcpa.c deleted file mode 100644 index 726cfa985d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/indcpa.c +++ /dev/null @@ -1,334 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "indcpa.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" 
-#include "symmetric.h" -#include "randombytes.h" - -/************************************************* -* Name: pack_pk -* -* Description: Serialize the public key as concatenation of the -* serialized vector of polynomials pk -* and the public seed used to generate the matrix A. -* -* Arguments: uint8_t *r: pointer to the output serialized public key -* polyvec *pk: pointer to the input public-key polyvec -* const uint8_t *seed: pointer to the input public seed -**************************************************/ -static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], - polyvec *pk, - const uint8_t seed[KYBER_SYMBYTES]) -{ - polyvec_tobytes(r, pk); - memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES); -} - -/************************************************* -* Name: unpack_pk -* -* Description: De-serialize public key from a byte array; -* approximate inverse of pack_pk -* -* Arguments: - polyvec *pk: pointer to output public-key polynomial vector -* - uint8_t *seed: pointer to output seed to generate matrix A -* - const uint8_t *packedpk: pointer to input serialized public key -**************************************************/ -static void unpack_pk(polyvec *pk, - uint8_t seed[KYBER_SYMBYTES], - const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) -{ - polyvec_frombytes(pk, packedpk); - memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES); -} - -/************************************************* -* Name: pack_sk -* -* Description: Serialize the secret key -* -* Arguments: - uint8_t *r: pointer to output serialized secret key -* - polyvec *sk: pointer to input vector of polynomials (secret key) -**************************************************/ -static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) -{ - polyvec_tobytes(r, sk); -} - -/************************************************* -* Name: unpack_sk -* -* Description: De-serialize the secret key; inverse of pack_sk -* -* Arguments: - polyvec *sk: pointer to output vector of 
polynomials (secret key) -* - const uint8_t *packedsk: pointer to input serialized secret key -**************************************************/ -static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) -{ - polyvec_frombytes(sk, packedsk); -} - -/************************************************* -* Name: pack_ciphertext -* -* Description: Serialize the ciphertext as concatenation of the -* compressed and serialized vector of polynomials b -* and the compressed and serialized polynomial v -* -* Arguments: uint8_t *r: pointer to the output serialized ciphertext -* poly *pk: pointer to the input vector of polynomials b -* poly *v: pointer to the input polynomial v -**************************************************/ -static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v) -{ - polyvec_compress(r, b); - poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v); -} - -/************************************************* -* Name: unpack_ciphertext -* -* Description: De-serialize and decompress ciphertext from a byte array; -* approximate inverse of pack_ciphertext -* -* Arguments: - polyvec *b: pointer to the output vector of polynomials b -* - poly *v: pointer to the output polynomial v -* - const uint8_t *c: pointer to the input serialized ciphertext -**************************************************/ -static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES]) -{ - polyvec_decompress(b, c); - poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES); -} - -/************************************************* -* Name: rej_uniform -* -* Description: Run rejection sampling on uniform random bytes to generate -* uniform random integers mod q -* -* Arguments: - int16_t *r: pointer to output buffer -* - unsigned int len: requested number of 16-bit integers (uniform mod q) -* - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes) -* - unsigned int buflen: length of input buffer 
in bytes -* -* Returns number of sampled 16-bit integers (at most len) -**************************************************/ -static unsigned int rej_uniform(int16_t *r, - unsigned int len, - const uint8_t *buf, - unsigned int buflen) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - - ctr = pos = 0; - while(ctr < len && pos + 3 <= buflen) { - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF; - pos += 3; - - if(val0 < KYBER_Q) - r[ctr++] = val0; - if(ctr < len && val1 < KYBER_Q) - r[ctr++] = val1; - } - - return ctr; -} - -#define gen_a(A,B) gen_matrix(A,B,0) -#define gen_at(A,B) gen_matrix(A,B,1) - -/************************************************* -* Name: gen_matrix -* -* Description: Deterministically generate matrix A (or the transpose of A) -* from a seed. Entries of the matrix are polynomials that look -* uniformly random. Performs rejection sampling on output of -* a XOF -* -* Arguments: - polyvec *a: pointer to ouptput matrix A -* - const uint8_t *seed: pointer to input seed -* - int transposed: boolean deciding whether A or A^T is generated -**************************************************/ -#if(XOF_BLOCKBYTES % 3) -#error "Implementation of gen_matrix assumes that XOF_BLOCKBYTES is a multiple of 3" -#endif - -#define GEN_MATRIX_NBLOCKS ((12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) -// Not static for benchmarking -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed) -{ - unsigned int ctr, i, j; - unsigned int buflen; - uint8_t buf[GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]; - xof_state state; - xof_init(&state, seed); - - for(i=0;i -#include "params.h" -#include "polyvec.h" - -#define gen_matrix KYBER_NAMESPACE(gen_matrix) -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed); - -#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand) -void indcpa_keypair_derand(uint8_t 
pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_enc KYBER_NAMESPACE(indcpa_enc) -void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_dec KYBER_NAMESPACE(indcpa_dec) -void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.c deleted file mode 100644 index 63abc1029c..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.c +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "kem.h" -#include "indcpa.h" -#include "verify.h" -#include "symmetric.h" -#include "randombytes.h" -/************************************************* -* Name: crypto_kem_keypair_derand -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* - uint8_t *coins: pointer to input randomness -* (an already allocated array filled with 2*KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair_derand(uint8_t *pk, - uint8_t *sk, - const uint8_t *coins) -{ - indcpa_keypair_derand(pk, sk, coins); - memcpy(sk+KYBER_INDCPA_SECRETKEYBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_h(sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - /* Value z for pseudo-random output on reject */ - memcpy(sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES, 
coins+KYBER_SYMBYTES, KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_keypair -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair(uint8_t *pk, - uint8_t *sk) -{ - uint8_t coins[2*KYBER_SYMBYTES]; - randombytes(coins, 2*KYBER_SYMBYTES); - crypto_kem_keypair_derand(pk, sk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc_derand -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - const uint8_t *coins: pointer to input randomness -* (an already allocated array filled with KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc_derand(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk, - const uint8_t *coins) -{ - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - - memcpy(buf, coins, KYBER_SYMBYTES); - - /* Multitarget countermeasure for coins + contributory KEM */ - hash_h(buf+KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(ct, buf, pk, kr+KYBER_SYMBYTES); - - memcpy(ss,kr,KYBER_SYMBYTES); - return 0; -} - 
-/************************************************* -* Name: crypto_kem_enc -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk) -{ - uint8_t coins[KYBER_SYMBYTES]; - randombytes(coins, KYBER_SYMBYTES); - crypto_kem_enc_derand(ct, ss, pk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_dec -* -* Description: Generates shared secret for given -* cipher text and private key -* -* Arguments: - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *ct: pointer to input cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - const uint8_t *sk: pointer to input private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0. -* -* On failure, ss will contain a pseudo-random value. 
-**************************************************/ -int crypto_kem_dec(uint8_t *ss, - const uint8_t *ct, - const uint8_t *sk) -{ - int fail; - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - uint8_t cmp[KYBER_CIPHERTEXTBYTES+KYBER_SYMBYTES]; - const uint8_t *pk = sk+KYBER_INDCPA_SECRETKEYBYTES; - - indcpa_dec(buf, ct, sk); - - /* Multitarget countermeasure for coins + contributory KEM */ - memcpy(buf+KYBER_SYMBYTES, sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, KYBER_SYMBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(cmp, buf, pk, kr+KYBER_SYMBYTES); - - fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES); - - /* Compute rejection key */ - rkprf(ss,sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES,ct); - - /* Copy true key to return buffer if fail is false */ - cmov(ss,kr,KYBER_SYMBYTES,!fail); - - return 0; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.h deleted file mode 100644 index 234f11966b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/kem.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef KEM_H -#define KEM_H - -#include -#include "params.h" - -#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES -#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES -#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES -#define CRYPTO_BYTES KYBER_SSBYTES - -#if (KYBER_K == 2) -#define CRYPTO_ALGNAME "Kyber512" -#elif (KYBER_K == 3) -#define CRYPTO_ALGNAME "Kyber768" -#elif (KYBER_K == 4) -#define CRYPTO_ALGNAME "Kyber1024" -#endif - -#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand) -int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); - -#define crypto_kem_keypair KYBER_NAMESPACE(keypair) -int crypto_kem_keypair(uint8_t *pk, uint8_t *sk); - -#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand) -int crypto_kem_enc_derand(uint8_t *ct, uint8_t 
*ss, const uint8_t *pk, const uint8_t *coins); - -#define crypto_kem_enc KYBER_NAMESPACE(enc) -int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); - -#define crypto_kem_dec KYBER_NAMESPACE(dec) -int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.c deleted file mode 100644 index 2f2eb10b2f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.c +++ /dev/null @@ -1,146 +0,0 @@ -#include -#include "params.h" -#include "ntt.h" -#include "reduce.h" - -/* Code to generate zetas and zetas_inv used in the number-theoretic transform: - -#define KYBER_ROOT_OF_UNITY 17 - -static const uint8_t tree[128] = { - 0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, - 4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124, - 2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122, - 6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126, - 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, - 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, - 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, - 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127 -}; - -void init_ntt() { - unsigned int i; - int16_t tmp[128]; - - tmp[0] = MONT; - for(i=1;i<128;i++) - tmp[i] = fqmul(tmp[i-1],MONT*KYBER_ROOT_OF_UNITY % KYBER_Q); - - for(i=0;i<128;i++) { - zetas[i] = tmp[tree[i]]; - if(zetas[i] > KYBER_Q/2) - zetas[i] -= KYBER_Q; - if(zetas[i] < -KYBER_Q/2) - zetas[i] += KYBER_Q; - } -} -*/ - -const int16_t zetas[128] = { - -1044, -758, -359, -1517, 1493, 1422, 287, 202, - -171, 622, 1577, 182, 962, -1202, -1474, 1468, - 573, -1325, 264, 383, -829, 1458, -1602, -130, - -681, 1017, 732, 608, -1542, 411, -205, -1571, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, -1618, -1162, 
126, 1469, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -1103, 430, 555, 843, -1251, 871, 1550, 105, - 422, 587, 177, -235, -291, -460, 1574, 1653, - -246, 778, 1159, -147, -777, 1483, -602, 1119, - -1590, 644, -872, 349, 418, 329, -156, -75, - 817, 1097, 603, 610, 1322, -1285, -1465, 384, - -1215, -136, 1218, -1335, -874, 220, -1187, -1659, - -1185, -1530, -1278, 794, -1510, -854, -870, 478, - -108, -308, 996, 991, 958, -1460, 1522, 1628 -}; - -/************************************************* -* Name: fqmul -* -* Description: Multiplication followed by Montgomery reduction -* -* Arguments: - int16_t a: first factor -* - int16_t b: second factor -* -* Returns 16-bit integer congruent to a*b*R^{-1} mod q -**************************************************/ -static int16_t fqmul(int16_t a, int16_t b) { - return montgomery_reduce((int32_t)a*b); -} - -/************************************************* -* Name: ntt -* -* Description: Inplace number-theoretic transform (NTT) in Rq. -* input is in standard order, output is in bitreversed order -* -* Arguments: - int16_t r[256]: pointer to input/output vector of elements of Zq -**************************************************/ -void ntt(int16_t r[256]) { - unsigned int len, start, j, k; - int16_t t, zeta; - - k = 1; - for(len = 128; len >= 2; len >>= 1) { - for(start = 0; start < 256; start = j + len) { - zeta = zetas[k++]; - for(j = start; j < start + len; j++) { - t = fqmul(zeta, r[j + len]); - r[j + len] = r[j] - t; - r[j] = r[j] + t; - } - } - } -} - -/************************************************* -* Name: invntt_tomont -* -* Description: Inplace inverse number-theoretic transform in Rq and -* multiplication by Montgomery factor 2^16. 
-* Input is in bitreversed order, output is in standard order -* -* Arguments: - int16_t r[256]: pointer to input/output vector of elements of Zq -**************************************************/ -void invntt(int16_t r[256]) { - unsigned int start, len, j, k; - int16_t t, zeta; - const int16_t f = 1441; // mont^2/128 - - k = 127; - for(len = 2; len <= 128; len <<= 1) { - for(start = 0; start < 256; start = j + len) { - zeta = zetas[k--]; - for(j = start; j < start + len; j++) { - t = r[j]; - r[j] = barrett_reduce(t + r[j + len]); - r[j + len] = r[j + len] - t; - r[j + len] = fqmul(zeta, r[j + len]); - } - } - } - - for(j = 0; j < 256; j++) - r[j] = fqmul(r[j], f); -} - -/************************************************* -* Name: basemul -* -* Description: Multiplication of polynomials in Zq[X]/(X^2-zeta) -* used for multiplication of elements in Rq in NTT domain -* -* Arguments: - int16_t r[2]: pointer to the output polynomial -* - const int16_t a[2]: pointer to the first factor -* - const int16_t b[2]: pointer to the second factor -* - int16_t zeta: integer defining the reduction polynomial -**************************************************/ -void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta) -{ - r[0] = fqmul(a[1], b[1]); - r[0] = fqmul(r[0], zeta); - r[0] += fqmul(a[0], b[0]); - r[1] = fqmul(a[0], b[1]); - r[1] += fqmul(a[1], b[0]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.h deleted file mode 100644 index 227ea74f08..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/ntt.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef NTT_H -#define NTT_H - -#include -#include "params.h" - -#define zetas KYBER_NAMESPACE(zetas) -extern const int16_t zetas[128]; - -#define ntt KYBER_NAMESPACE(ntt) -void ntt(int16_t poly[256]); - -#define invntt KYBER_NAMESPACE(invntt) -void invntt(int16_t poly[256]); - -#define basemul 
KYBER_NAMESPACE(basemul) -void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/params.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/params.h deleted file mode 100644 index fb4190b311..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/params.h +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef PARAMS_H -#define PARAMS_H - -#ifndef KYBER_K -#define KYBER_K 3 /* Change this for different security strengths */ -#endif - - -/* Don't change parameters below this line */ -#if (KYBER_K == 2) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_512_ref_##s -#elif (KYBER_K == 3) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_768_ref_##s -#elif (KYBER_K == 4) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_1024_ref_##s -#else -#error "KYBER_K must be in {2,3,4}" -#endif - -#define KYBER_N 256 -#define KYBER_Q 3329 - -#define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */ -#define KYBER_SSBYTES 32 /* size in bytes of shared key */ - -#define KYBER_POLYBYTES 384 -#define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES) - -#if KYBER_K == 2 -#define KYBER_ETA1 3 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 3 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 4 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 160 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352) -#endif - -#define KYBER_ETA2 2 - -#define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES) -#define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES) -#define KYBER_INDCPA_SECRETKEYBYTES (KYBER_POLYVECBYTES) -#define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES) - -#define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES) -/* 32 bytes of additional space to save H(pk) */ -#define 
KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2*KYBER_SYMBYTES) -#define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES) - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.c deleted file mode 100644 index cbd3abfb54..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.c +++ /dev/null @@ -1,360 +0,0 @@ -#include -#include "params.h" -#include "poly.h" -#include "ntt.h" -#include "reduce.h" -#include "cbd.h" -#include "symmetric.h" -#include "verify.h" - -/************************************************* -* Name: poly_compress -* -* Description: Compression and subsequent serialization of a polynomial -* -* Arguments: - uint8_t *r: pointer to output byte array -* (of length KYBER_POLYCOMPRESSEDBYTES) -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a) -{ - unsigned int i,j; - int16_t u; - uint32_t d0; - uint8_t t[8]; - -#if (KYBER_POLYCOMPRESSEDBYTES == 128) - - for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint16_t)u << 4) + KYBER_Q/2)/KYBER_Q) & 15; */ - d0 = u << 4; - d0 += 1665; - d0 *= 80635; - d0 >>= 28; - t[j] = d0 & 0xf; - } - - r[0] = t[0] | (t[1] << 4); - r[1] = t[2] | (t[3] << 4); - r[2] = t[4] | (t[5] << 4); - r[3] = t[6] | (t[7] << 4); - r += 4; - } -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) - for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint32_t)u << 5) + KYBER_Q/2)/KYBER_Q) & 31; */ - d0 = u << 5; - d0 += 1664; - d0 *= 40318; - d0 >>= 27; - t[j] = d0 & 0x1f; - } - - r[0] = (t[0] >> 0) | (t[1] << 5); - r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7); - r[2] = (t[3] >> 1) | (t[4] << 4); - r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6); - r[4] = (t[6] >> 2) | (t[7] << 3); - r += 5; - } -#else -#error "KYBER_POLYCOMPRESSEDBYTES needs to be in 
{128, 160}" -#endif -} - -/************************************************* -* Name: poly_decompress -* -* Description: De-serialization and subsequent decompression of a polynomial; -* approximate inverse of poly_compress -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYCOMPRESSEDBYTES bytes) -**************************************************/ -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]) -{ - unsigned int i; - -#if (KYBER_POLYCOMPRESSEDBYTES == 128) - for(i=0;icoeffs[2*i+0] = (((uint16_t)(a[0] & 15)*KYBER_Q) + 8) >> 4; - r->coeffs[2*i+1] = (((uint16_t)(a[0] >> 4)*KYBER_Q) + 8) >> 4; - a += 1; - } -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) - unsigned int j; - uint8_t t[8]; - for(i=0;i> 0); - t[1] = (a[0] >> 5) | (a[1] << 3); - t[2] = (a[1] >> 2); - t[3] = (a[1] >> 7) | (a[2] << 1); - t[4] = (a[2] >> 4) | (a[3] << 4); - t[5] = (a[3] >> 1); - t[6] = (a[3] >> 6) | (a[4] << 2); - t[7] = (a[4] >> 3); - a += 5; - - for(j=0;j<8;j++) - r->coeffs[8*i+j] = ((uint32_t)(t[j] & 31)*KYBER_Q + 16) >> 5; - } -#else -#error "KYBER_POLYCOMPRESSEDBYTES needs to be in {128, 160}" -#endif -} - -/************************************************* -* Name: poly_tobytes -* -* Description: Serialization of a polynomial -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYBYTES bytes) -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) -{ - unsigned int i; - uint16_t t0, t1; - - for(i=0;icoeffs[2*i]; - t0 += ((int16_t)t0 >> 15) & KYBER_Q; - t1 = a->coeffs[2*i+1]; - t1 += ((int16_t)t1 >> 15) & KYBER_Q; - r[3*i+0] = (t0 >> 0); - r[3*i+1] = (t0 >> 8) | (t1 << 4); - r[3*i+2] = (t1 >> 4); - } -} - -/************************************************* -* Name: poly_frombytes -* -* Description: De-serialization of a polynomial; -* inverse of 
poly_tobytes -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of KYBER_POLYBYTES bytes) -**************************************************/ -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) -{ - unsigned int i; - for(i=0;icoeffs[2*i] = ((a[3*i+0] >> 0) | ((uint16_t)a[3*i+1] << 8)) & 0xFFF; - r->coeffs[2*i+1] = ((a[3*i+1] >> 4) | ((uint16_t)a[3*i+2] << 4)) & 0xFFF; - } -} - -/************************************************* -* Name: poly_frommsg -* -* Description: Convert 32-byte message to polynomial -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *msg: pointer to input message -**************************************************/ -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) -{ - unsigned int i,j; - -#if (KYBER_INDCPA_MSGBYTES != KYBER_N/8) -#error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!" -#endif - - for(i=0;icoeffs[8*i+j] = 0; - cmov_int16(r->coeffs+8*i+j, ((KYBER_Q+1)/2), (msg[i] >> j)&1); - } - } -} - -/************************************************* -* Name: poly_tomsg -* -* Description: Convert polynomial to 32-byte message -* -* Arguments: - uint8_t *msg: pointer to output message -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a) -{ - unsigned int i,j; - uint32_t t; - - for(i=0;icoeffs[8*i+j]; - // t += ((int16_t)t >> 15) & KYBER_Q; - // t = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1; - t <<= 1; - t += 1665; - t *= 80635; - t >>= 28; - t &= 1; - msg[i] |= t << j; - } - } -} - -/************************************************* -* Name: poly_getnoise_eta1 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA1 -* -* Arguments: - poly *r: pointer to output polynomial -* - const 
uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t buf[KYBER_ETA1*KYBER_N/4]; - prf(buf, sizeof(buf), seed, nonce); - poly_cbd_eta1(r, buf); -} - -/************************************************* -* Name: poly_getnoise_eta2 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t buf[KYBER_ETA2*KYBER_N/4]; - prf(buf, sizeof(buf), seed, nonce); - poly_cbd_eta2(r, buf); -} - - -/************************************************* -* Name: poly_ntt -* -* Description: Computes negacyclic number-theoretic transform (NTT) of -* a polynomial in place; -* inputs assumed to be in normal order, output in bitreversed order -* -* Arguments: - uint16_t *r: pointer to in/output polynomial -**************************************************/ -void poly_ntt(poly *r) -{ - ntt(r->coeffs); - poly_reduce(r); -} - -/************************************************* -* Name: poly_invntt_tomont -* -* Description: Computes inverse of negacyclic number-theoretic transform (NTT) -* of a polynomial in place; -* inputs assumed to be in bitreversed order, output in normal order -* -* Arguments: - uint16_t *a: pointer to in/output polynomial -**************************************************/ -void poly_invntt_tomont(poly *r) -{ - invntt(r->coeffs); -} - -/************************************************* -* Name: poly_basemul_montgomery 
-* -* Description: Multiplication of two polynomials in NTT domain -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[4*i], &a->coeffs[4*i], &b->coeffs[4*i], zetas[64+i]); - basemul(&r->coeffs[4*i+2], &a->coeffs[4*i+2], &b->coeffs[4*i+2], -zetas[64+i]); - } -} - -/************************************************* -* Name: poly_tomont -* -* Description: Inplace conversion of all coefficients of a polynomial -* from normal domain to Montgomery domain -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_tomont(poly *r) -{ - unsigned int i; - const int16_t f = (1ULL << 32) % KYBER_Q; - for(i=0;icoeffs[i] = montgomery_reduce((int32_t)r->coeffs[i]*f); -} - -/************************************************* -* Name: poly_reduce -* -* Description: Applies Barrett reduction to all coefficients of a polynomial -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_reduce(poly *r) -{ - unsigned int i; - for(i=0;icoeffs[i] = barrett_reduce(r->coeffs[i]); -} - -/************************************************* -* Name: poly_add -* -* Description: Add two polynomials; no modular reduction is performed -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_add(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[i] = a->coeffs[i] + b->coeffs[i]; -} - -/************************************************* -* 
Name: poly_sub -* -* Description: Subtract two polynomials; no modular reduction is performed -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_sub(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[i] = a->coeffs[i] - b->coeffs[i]; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.h deleted file mode 100644 index 9a99c7cdad..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/poly.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef POLY_H -#define POLY_H - -#include -#include "params.h" - -/* - * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial - * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] - */ -typedef struct{ - int16_t coeffs[KYBER_N]; -} poly; - -#define poly_compress KYBER_NAMESPACE(poly_compress) -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a); -#define poly_decompress KYBER_NAMESPACE(poly_decompress) -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]); - -#define poly_tobytes KYBER_NAMESPACE(poly_tobytes) -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a); -#define poly_frombytes KYBER_NAMESPACE(poly_frombytes) -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]); - -#define poly_frommsg KYBER_NAMESPACE(poly_frommsg) -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]); -#define poly_tomsg KYBER_NAMESPACE(poly_tomsg) -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r); - -#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1) -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2) -void poly_getnoise_eta2(poly *r, 
const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_ntt KYBER_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); -#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont) -void poly_invntt_tomont(poly *r); -#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery) -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b); -#define poly_tomont KYBER_NAMESPACE(poly_tomont) -void poly_tomont(poly *r); - -#define poly_reduce KYBER_NAMESPACE(poly_reduce) -void poly_reduce(poly *r); - -#define poly_add KYBER_NAMESPACE(poly_add) -void poly_add(poly *r, const poly *a, const poly *b); -#define poly_sub KYBER_NAMESPACE(poly_sub) -void poly_sub(poly *r, const poly *a, const poly *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.c deleted file mode 100644 index 669f6a5f1d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.c +++ /dev/null @@ -1,246 +0,0 @@ -#include -#include "params.h" -#include "poly.h" -#include "polyvec.h" - -/************************************************* -* Name: polyvec_compress -* -* Description: Compress and serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECCOMPRESSEDBYTES) -* - const polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a) -{ - unsigned int i,j,k; - uint64_t d0; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - uint16_t t[8]; - for(i=0;ivec[i].coeffs[8*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 11) + KYBER_Q/2)/KYBER_Q) & 0x7ff; */ - d0 = t[k]; - d0 <<= 11; - d0 += 1664; - d0 *= 645084; - d0 >>= 31; - t[k] = d0 & 0x7ff; - } - - r[ 0] = (t[0] >> 0); - r[ 1] = (t[0] >> 8) | (t[1] << 3); - r[ 2] = (t[1] >> 5) 
| (t[2] << 6); - r[ 3] = (t[2] >> 2); - r[ 4] = (t[2] >> 10) | (t[3] << 1); - r[ 5] = (t[3] >> 7) | (t[4] << 4); - r[ 6] = (t[4] >> 4) | (t[5] << 7); - r[ 7] = (t[5] >> 1); - r[ 8] = (t[5] >> 9) | (t[6] << 2); - r[ 9] = (t[6] >> 6) | (t[7] << 5); - r[10] = (t[7] >> 3); - r += 11; - } - } -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - uint16_t t[4]; - for(i=0;ivec[i].coeffs[4*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 10) + KYBER_Q/2)/ KYBER_Q) & 0x3ff; */ - d0 = t[k]; - d0 <<= 10; - d0 += 1665; - d0 *= 1290167; - d0 >>= 32; - t[k] = d0 & 0x3ff; - } - - r[0] = (t[0] >> 0); - r[1] = (t[0] >> 8) | (t[1] << 2); - r[2] = (t[1] >> 6) | (t[2] << 4); - r[3] = (t[2] >> 4) | (t[3] << 6); - r[4] = (t[3] >> 2); - r += 5; - } - } -#else -#error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}" -#endif -} - -/************************************************* -* Name: polyvec_decompress -* -* Description: De-serialize and decompress vector of polynomials; -* approximate inverse of polyvec_compress -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYVECCOMPRESSEDBYTES) -**************************************************/ -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]) -{ - unsigned int i,j,k; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - uint16_t t[8]; - for(i=0;i> 0) | ((uint16_t)a[ 1] << 8); - t[1] = (a[1] >> 3) | ((uint16_t)a[ 2] << 5); - t[2] = (a[2] >> 6) | ((uint16_t)a[ 3] << 2) | ((uint16_t)a[4] << 10); - t[3] = (a[4] >> 1) | ((uint16_t)a[ 5] << 7); - t[4] = (a[5] >> 4) | ((uint16_t)a[ 6] << 4); - t[5] = (a[6] >> 7) | ((uint16_t)a[ 7] << 1) | ((uint16_t)a[8] << 9); - t[6] = (a[8] >> 2) | ((uint16_t)a[ 9] << 6); - t[7] = (a[9] >> 5) | ((uint16_t)a[10] << 3); - a += 11; - - for(k=0;k<8;k++) - r->vec[i].coeffs[8*j+k] = ((uint32_t)(t[k] & 0x7FF)*KYBER_Q + 1024) >> 
11; - } - } -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - uint16_t t[4]; - for(i=0;i> 0) | ((uint16_t)a[1] << 8); - t[1] = (a[1] >> 2) | ((uint16_t)a[2] << 6); - t[2] = (a[2] >> 4) | ((uint16_t)a[3] << 4); - t[3] = (a[3] >> 6) | ((uint16_t)a[4] << 2); - a += 5; - - for(k=0;k<4;k++) - r->vec[i].coeffs[4*j+k] = ((uint32_t)(t[k] & 0x3FF)*KYBER_Q + 512) >> 10; - } - } -#else -#error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}" -#endif -} - -/************************************************* -* Name: polyvec_tobytes -* -* Description: Serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECBYTES) -* - const polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_frombytes -* -* Description: De-serialize vector of polynomials; -* inverse of polyvec_tobytes -* -* Arguments: - uint8_t *r: pointer to output byte array -* - const polyvec *a: pointer to input vector of polynomials -* (of length KYBER_POLYVECBYTES) -**************************************************/ -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]) -{ - unsigned int i; - for(i=0;ivec[i], a+i*KYBER_POLYBYTES); -} - -/************************************************* -* Name: polyvec_ntt -* -* Description: Apply forward NTT to all elements of a vector of polynomials -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_ntt(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_invntt_tomont -* -* Description: Apply inverse NTT to all elements of a vector of polynomials -* and multiply by 
Montgomery factor 2^16 -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_invntt_tomont(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_basemul_acc_montgomery -* -* Description: Multiply elements of a and b in NTT domain, accumulate into r, -* and multiply by 2^-16. -* -* Arguments: - poly *r: pointer to output polynomial -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - poly t; - - poly_basemul_montgomery(r, &a->vec[0], &b->vec[0]); - for(i=1;ivec[i], &b->vec[i]); - poly_add(r, r, &t); - } - - poly_reduce(r); -} - -/************************************************* -* Name: polyvec_reduce -* -* Description: Applies Barrett reduction to each coefficient -* of each element of a vector of polynomials; -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - polyvec *r: pointer to input/output polynomial -**************************************************/ -void polyvec_reduce(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_add -* -* Description: Add vectors of polynomials -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - for(i=0;ivec[i], &a->vec[i], &b->vec[i]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.h 
b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.h deleted file mode 100644 index 57b605494e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/polyvec.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef POLYVEC_H -#define POLYVEC_H - -#include -#include "params.h" -#include "poly.h" - -typedef struct{ - poly vec[KYBER_K]; -} polyvec; - -#define polyvec_compress KYBER_NAMESPACE(polyvec_compress) -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a); -#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress) -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]); - -#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes) -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a); -#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes) -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]); - -#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt) -void polyvec_ntt(polyvec *r); -#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont) -void polyvec_invntt_tomont(polyvec *r); - -#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery) -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b); - -#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce) -void polyvec_reduce(polyvec *r); - -#define polyvec_add KYBER_NAMESPACE(polyvec_add) -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.c deleted file mode 100644 index 9d8e7edf83..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.c +++ /dev/null @@ -1,42 +0,0 @@ -#include -#include "params.h" -#include "reduce.h" - -/************************************************* -* Name: montgomery_reduce -* -* Description: Montgomery reduction; given a 
32-bit integer a, computes -* 16-bit integer congruent to a * R^-1 mod q, where R=2^16 -* -* Arguments: - int32_t a: input integer to be reduced; -* has to be in {-q2^15,...,q2^15-1} -* -* Returns: integer in {-q+1,...,q-1} congruent to a * R^-1 modulo q. -**************************************************/ -int16_t montgomery_reduce(int32_t a) -{ - int16_t t; - - t = (int16_t)a*QINV; - t = (a - (int32_t)t*KYBER_Q) >> 16; - return t; -} - -/************************************************* -* Name: barrett_reduce -* -* Description: Barrett reduction; given a 16-bit integer a, computes -* centered representative congruent to a mod q in {-(q-1)/2,...,(q-1)/2} -* -* Arguments: - int16_t a: input integer to be reduced -* -* Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. -**************************************************/ -int16_t barrett_reduce(int16_t a) { - int16_t t; - const int16_t v = ((1<<26) + KYBER_Q/2)/KYBER_Q; - - t = ((int32_t)v*a + (1<<25)) >> 26; - t *= KYBER_Q; - return a - t; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.h deleted file mode 100644 index c1bc1e4c7b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/reduce.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef REDUCE_H -#define REDUCE_H - -#include -#include "params.h" - -#define MONT -1044 // 2^16 mod q -#define QINV -3327 // q^-1 mod 2^16 - -#define montgomery_reduce KYBER_NAMESPACE(montgomery_reduce) -int16_t montgomery_reduce(int32_t a); - -#define barrett_reduce KYBER_NAMESPACE(barrett_reduce) -int16_t barrett_reduce(int16_t a); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric-shake.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric-shake.c deleted file mode 100644 index 20f451882e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric-shake.c +++ /dev/null @@ -1,74 
+0,0 @@ -#include -#include -#include -#include "params.h" -#include "symmetric.h" -#include "fips202.h" - -/************************************************* -* Name: kyber_shake128_absorb -* -* Description: Absorb step of the SHAKE128 specialized for the Kyber context. -* -* Arguments: - keccak_state *state: pointer to (uninitialized) output Keccak state -* - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state -* - uint8_t i: additional byte of input -* - uint8_t j: additional byte of input -**************************************************/ -void kyber_shake128_absorb(shake128incctx *state, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y) -{ - uint8_t extseed[KYBER_SYMBYTES+2]; - - memcpy(extseed, seed, KYBER_SYMBYTES); - extseed[KYBER_SYMBYTES+0] = x; - extseed[KYBER_SYMBYTES+1] = y; - - shake128_absorb_once(state, extseed, sizeof(extseed)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t extkey[KYBER_SYMBYTES+1]; - - memcpy(extkey, key, KYBER_SYMBYTES); - extkey[KYBER_SYMBYTES] = nonce; - - shake256(out, outlen, extkey, sizeof(extkey)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output 
bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]) -{ - shake256incctx s; - - shake256_inc_init(&s); - shake256_inc_absorb(&s, key, KYBER_SYMBYTES); - shake256_inc_absorb(&s, input, KYBER_CIPHERTEXTBYTES); - shake256_inc_finalize(&s); - shake256_inc_squeeze(out, KYBER_SSBYTES, &s); - shake256_inc_ctx_release(&s); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric.h deleted file mode 100644 index 2acc66f98d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/symmetric.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef SYMMETRIC_H -#define SYMMETRIC_H - -#include -#include -#include "params.h" - -#include "fips202.h" - -typedef shake128incctx xof_state; - -#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb) -void kyber_shake128_absorb(shake128incctx *s, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y); - -#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf) -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce); - -#define kyber_shake256_rkprf KYBER_NAMESPACE(kyber_shake256_rkprf) -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]); - -#define XOF_BLOCKBYTES SHAKE128_RATE - -#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) -#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) -#define xof_init(STATE, SEED) shake128_inc_init(STATE) -#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y) -#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) 
-#define xof_release(STATE) shake128_inc_ctx_release(STATE) -#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE) -#define rkprf(OUT, KEY, INPUT) kyber_shake256_rkprf(OUT, KEY, INPUT) - -#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/verify.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/verify.c deleted file mode 100644 index 914ccd448f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-512_ref/verify.c +++ /dev/null @@ -1,75 +0,0 @@ -#include -#include -#include "verify.h" - -/************************************************* -* Name: verify -* -* Description: Compare two arrays for equality in constant time. -* -* Arguments: const uint8_t *a: pointer to first byte array -* const uint8_t *b: pointer to second byte array -* size_t len: length of the byte arrays -* -* Returns 0 if the byte arrays are equal, 1 otherwise -**************************************************/ -int verify(const uint8_t *a, const uint8_t *b, size_t len) -{ - size_t i; - uint8_t r = 0; - - for(i=0;i> 63; -} - -/************************************************* -* Name: cmov -* -* Description: Copy len bytes from x to r if b is 1; -* don't modify x if b is 0. Requires b to be in {0,1}; -* assumes two's complement representation of negative integers. -* Runs in constant time. -* -* Arguments: uint8_t *r: pointer to output byte array -* const uint8_t *x: pointer to input byte array -* size_t len: Amount of bytes to be copied -* uint8_t b: Condition bit; has to be in {0,1} -**************************************************/ -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b) -{ - size_t i; - -#if defined(__GNUC__) || defined(__clang__) - // Prevent the compiler from - // 1) inferring that b is 0/1-valued, and - // 2) handling the two cases with a branch. 
- // This is not necessary when verify.c and kem.c are separate translation - // units, but we expect that downstream consumers will copy this code and/or - // change how it is built. - __asm__("" : "+r"(b) : /* no inputs */); -#endif - - b = -b; - for(i=0;i -#include -#include "params.h" - -#define verify KYBER_NAMESPACE(verify) -int verify(const uint8_t *a, const uint8_t *b, size_t len); - -#define cmov KYBER_NAMESPACE(cmov) -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); - -#define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/align.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/align.h deleted file mode 100644 index 3463866f37..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/align.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef ALIGN_H -#define ALIGN_H - -#include -#include - -#define ALIGNED_UINT8(N) \ - union { \ - uint8_t coeffs[N]; \ - __m256i vec[(N+31)/32]; \ - } - -#define ALIGNED_INT16(N) \ - union { \ - int16_t coeffs[N]; \ - __m256i vec[(N+15)/16]; \ - } - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/api.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/api.h deleted file mode 100644 index a154e80f1d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/api.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef API_H -#define API_H - -#include - -#define pqcrystals_kyber512_SECRETKEYBYTES 1632 -#define pqcrystals_kyber512_PUBLICKEYBYTES 800 -#define pqcrystals_kyber512_CIPHERTEXTBYTES 768 -#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber512_ENCCOINBYTES 32 -#define pqcrystals_kyber512_BYTES 32 - -#define pqcrystals_kyber512_avx2_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES -#define pqcrystals_kyber512_avx2_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES -#define 
pqcrystals_kyber512_avx2_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES -#define pqcrystals_kyber512_avx2_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES -#define pqcrystals_kyber512_avx2_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES -#define pqcrystals_kyber512_avx2_BYTES pqcrystals_kyber512_BYTES - -int pqcrystals_kyber512_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber512_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber512_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber512_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber512_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber768_SECRETKEYBYTES 2400 -#define pqcrystals_kyber768_PUBLICKEYBYTES 1184 -#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088 -#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber768_ENCCOINBYTES 32 -#define pqcrystals_kyber768_BYTES 32 - -#define pqcrystals_kyber768_avx2_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES -#define pqcrystals_kyber768_avx2_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES -#define pqcrystals_kyber768_avx2_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES -#define pqcrystals_kyber768_avx2_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES -#define pqcrystals_kyber768_avx2_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES -#define pqcrystals_kyber768_avx2_BYTES pqcrystals_kyber768_BYTES - -int pqcrystals_kyber768_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber768_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber768_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber768_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber768_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define 
pqcrystals_kyber1024_SECRETKEYBYTES 3168 -#define pqcrystals_kyber1024_PUBLICKEYBYTES 1568 -#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568 -#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber1024_ENCCOINBYTES 32 -#define pqcrystals_kyber1024_BYTES 32 - -#define pqcrystals_kyber1024_avx2_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES -#define pqcrystals_kyber1024_avx2_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES -#define pqcrystals_kyber1024_avx2_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES -#define pqcrystals_kyber1024_avx2_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES -#define pqcrystals_kyber1024_avx2_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES -#define pqcrystals_kyber1024_avx2_BYTES pqcrystals_kyber1024_BYTES - -int pqcrystals_kyber1024_avx2_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber1024_avx2_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber1024_avx2_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber1024_avx2_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber1024_avx2_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.c deleted file mode 100644 index dad473c79e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.c +++ /dev/null @@ -1,144 +0,0 @@ -#include -#include -#include "params.h" -#include "cbd.h" - -/************************************************* -* Name: cbd2 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const __m256i *buf: pointer to aligned input byte array 
-**************************************************/ -static void cbd2(poly * restrict r, const __m256i buf[2*KYBER_N/128]) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i mask55 = _mm256_set1_epi32(0x55555555); - const __m256i mask33 = _mm256_set1_epi32(0x33333333); - const __m256i mask03 = _mm256_set1_epi32(0x03030303); - const __m256i mask0F = _mm256_set1_epi32(0x0F0F0F0F); - - for(i = 0; i < KYBER_N/64; i++) { - f0 = _mm256_load_si256(&buf[i]); - - f1 = _mm256_srli_epi16(f0, 1); - f0 = _mm256_and_si256(mask55, f0); - f1 = _mm256_and_si256(mask55, f1); - f0 = _mm256_add_epi8(f0, f1); - - f1 = _mm256_srli_epi16(f0, 2); - f0 = _mm256_and_si256(mask33, f0); - f1 = _mm256_and_si256(mask33, f1); - f0 = _mm256_add_epi8(f0, mask33); - f0 = _mm256_sub_epi8(f0, f1); - - f1 = _mm256_srli_epi16(f0, 4); - f0 = _mm256_and_si256(mask0F, f0); - f1 = _mm256_and_si256(mask0F, f1); - f0 = _mm256_sub_epi8(f0, mask03); - f1 = _mm256_sub_epi8(f1, mask03); - - f2 = _mm256_unpacklo_epi8(f0, f1); - f3 = _mm256_unpackhi_epi8(f0, f1); - - f0 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(f2)); - f1 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(f2,1)); - f2 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(f3)); - f3 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(f3,1)); - - _mm256_store_si256(&r->vec[4*i+0], f0); - _mm256_store_si256(&r->vec[4*i+1], f2); - _mm256_store_si256(&r->vec[4*i+2], f1); - _mm256_store_si256(&r->vec[4*i+3], f3); - } -} - -#if KYBER_ETA1 == 3 -/************************************************* -* Name: cbd3 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=3 -* This function is only needed for Kyber-512 -* -* Arguments: - poly *r: pointer to output polynomial -* - const __m256i *buf: pointer to aligned input byte array -**************************************************/ -static void cbd3(poly * restrict r, const uint8_t 
buf[3*KYBER_N/4+8]) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i mask249 = _mm256_set1_epi32(0x249249); - const __m256i mask6DB = _mm256_set1_epi32(0x6DB6DB); - const __m256i mask07 = _mm256_set1_epi32(7); - const __m256i mask70 = _mm256_set1_epi32(7 << 16); - const __m256i mask3 = _mm256_set1_epi16(3); - const __m256i shufbidx = _mm256_set_epi8(-1,15,14,13,-1,12,11,10,-1, 9, 8, 7,-1, 6, 5, 4, - -1,11,10, 9,-1, 8, 7, 6,-1, 5, 4, 3,-1, 2, 1, 0); - - for(i = 0; i < KYBER_N/32; i++) { - f0 = _mm256_loadu_si256((__m256i *)&buf[24*i]); - f0 = _mm256_permute4x64_epi64(f0,0x94); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - - f1 = _mm256_srli_epi32(f0,1); - f2 = _mm256_srli_epi32(f0,2); - f0 = _mm256_and_si256(mask249,f0); - f1 = _mm256_and_si256(mask249,f1); - f2 = _mm256_and_si256(mask249,f2); - f0 = _mm256_add_epi32(f0,f1); - f0 = _mm256_add_epi32(f0,f2); - - f1 = _mm256_srli_epi32(f0,3); - f0 = _mm256_add_epi32(f0,mask6DB); - f0 = _mm256_sub_epi32(f0,f1); - - f1 = _mm256_slli_epi32(f0,10); - f2 = _mm256_srli_epi32(f0,12); - f3 = _mm256_srli_epi32(f0, 2); - f0 = _mm256_and_si256(f0,mask07); - f1 = _mm256_and_si256(f1,mask70); - f2 = _mm256_and_si256(f2,mask07); - f3 = _mm256_and_si256(f3,mask70); - f0 = _mm256_add_epi16(f0,f1); - f1 = _mm256_add_epi16(f2,f3); - f0 = _mm256_sub_epi16(f0,mask3); - f1 = _mm256_sub_epi16(f1,mask3); - - f2 = _mm256_unpacklo_epi32(f0,f1); - f3 = _mm256_unpackhi_epi32(f0,f1); - - f0 = _mm256_permute2x128_si256(f2,f3,0x20); - f1 = _mm256_permute2x128_si256(f2,f3,0x31); - - _mm256_store_si256(&r->vec[2*i+0], f0); - _mm256_store_si256(&r->vec[2*i+1], f1); - } -} -#endif - -/* buf 32 bytes longer for cbd3 */ -void poly_cbd_eta1(poly *r, const __m256i buf[KYBER_ETA1*KYBER_N/128+1]) -{ -#if KYBER_ETA1 == 2 - cbd2(r, buf); -#elif KYBER_ETA1 == 3 - cbd3(r, (uint8_t *)buf); -#else -#error "This implementation requires eta1 in {2,3}" -#endif -} - -void poly_cbd_eta2(poly *r, const __m256i buf[KYBER_ETA2*KYBER_N/128]) -{ -#if KYBER_ETA2 
== 2 - cbd2(r, buf); -#else -#error "This implementation requires eta2 = 2" -#endif -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.h deleted file mode 100644 index 05788e06b4..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/cbd.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef CBD_H -#define CBD_H - -#include -#include -#include "params.h" -#include "poly.h" - -#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1) -void poly_cbd_eta1(poly *r, const __m256i buf[KYBER_ETA1*KYBER_N/128+1]); - -#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2) -void poly_cbd_eta2(poly *r, const __m256i buf[KYBER_ETA2*KYBER_N/128]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.c deleted file mode 100644 index 84e596893d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.c +++ /dev/null @@ -1,121 +0,0 @@ -#include "align.h" -#include "params.h" -#include "consts.h" - -#define Q KYBER_Q -#define MONT -1044 // 2^16 mod q -#define QINV -3327 // q^-1 mod 2^16 -#define V 20159 // floor(2^26/q + 0.5) -#define FHI 1441 // mont^2/128 -#define FLO -10079 // qinv*FHI -#define MONTSQHI 1353 // mont^2 -#define MONTSQLO 20553 // qinv*MONTSQHI -#define MASK 4095 -#define SHIFT 32 - -const qdata_t qdata = {{ -#define _16XQ 0 - Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, Q, - -#define _16XQINV 16 - QINV, QINV, QINV, QINV, QINV, QINV, QINV, QINV, - QINV, QINV, QINV, QINV, QINV, QINV, QINV, QINV, - -#define _16XV 32 - V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, - -#define _16XFLO 48 - FLO, FLO, FLO, FLO, FLO, FLO, FLO, FLO, - FLO, FLO, FLO, FLO, FLO, FLO, FLO, FLO, - -#define _16XFHI 64 - FHI, FHI, FHI, FHI, FHI, FHI, FHI, FHI, - FHI, FHI, FHI, FHI, FHI, FHI, FHI, FHI, - -#define _16XMONTSQLO 80 - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, 
MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - MONTSQLO, MONTSQLO, MONTSQLO, MONTSQLO, - -#define _16XMONTSQHI 96 - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - MONTSQHI, MONTSQHI, MONTSQHI, MONTSQHI, - -#define _16XMASK 112 - MASK, MASK, MASK, MASK, MASK, MASK, MASK, MASK, - MASK, MASK, MASK, MASK, MASK, MASK, MASK, MASK, - -#define _REVIDXB 128 - 3854, 3340, 2826, 2312, 1798, 1284, 770, 256, - 3854, 3340, 2826, 2312, 1798, 1284, 770, 256, - -#define _REVIDXD 144 - 7, 0, 6, 0, 5, 0, 4, 0, 3, 0, 2, 0, 1, 0, 0, 0, - -#define _ZETAS_EXP 160 - 31498, 31498, 31498, 31498, -758, -758, -758, -758, - 5237, 5237, 5237, 5237, 1397, 1397, 1397, 1397, - 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, - 14745, 14745, 14745, 14745, 14745, 14745, 14745, 14745, - -359, -359, -359, -359, -359, -359, -359, -359, - -359, -359, -359, -359, -359, -359, -359, -359, - 13525, 13525, 13525, 13525, 13525, 13525, 13525, 13525, - -12402, -12402, -12402, -12402, -12402, -12402, -12402, -12402, - 1493, 1493, 1493, 1493, 1493, 1493, 1493, 1493, - 1422, 1422, 1422, 1422, 1422, 1422, 1422, 1422, - -20907, -20907, -20907, -20907, 27758, 27758, 27758, 27758, - -3799, -3799, -3799, -3799, -15690, -15690, -15690, -15690, - -171, -171, -171, -171, 622, 622, 622, 622, - 1577, 1577, 1577, 1577, 182, 182, 182, 182, - -5827, -5827, 17363, 17363, -26360, -26360, -29057, -29057, - 5571, 5571, -1102, -1102, 21438, 21438, -26242, -26242, - 573, 573, -1325, -1325, 264, 264, 383, 383, - -829, -829, 1458, 1458, -1602, -1602, -130, -130, - -5689, -6516, 1496, 30967, -23565, 20179, 20710, 25080, - -12796, 26616, 16064, -12442, 9134, -650, -25986, 27837, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, -1618, -1162, 126, 1469, - -335, -11477, -32227, 20494, -27738, 945, -14883, 6182, - 32010, 10631, 29175, -28762, -18486, 17560, -14430, -5276, - -1103, 555, -1251, 1550, 422, 
177, -291, 1574, - -246, 1159, -777, -602, -1590, -872, 418, -156, - 11182, 13387, -14233, -21655, 13131, -4587, 23092, 5493, - -32502, 30317, -18741, 12639, 20100, 18525, 19529, -12619, - 430, 843, 871, 105, 587, -235, -460, 1653, - 778, -147, 1483, 1119, 644, 349, 329, -75, - 787, 787, 787, 787, 787, 787, 787, 787, - 787, 787, 787, 787, 787, 787, 787, 787, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - -1517, -1517, -1517, -1517, -1517, -1517, -1517, -1517, - 28191, 28191, 28191, 28191, 28191, 28191, 28191, 28191, - -16694, -16694, -16694, -16694, -16694, -16694, -16694, -16694, - 287, 287, 287, 287, 287, 287, 287, 287, - 202, 202, 202, 202, 202, 202, 202, 202, - 10690, 10690, 10690, 10690, 1358, 1358, 1358, 1358, - -11202, -11202, -11202, -11202, 31164, 31164, 31164, 31164, - 962, 962, 962, 962, -1202, -1202, -1202, -1202, - -1474, -1474, -1474, -1474, 1468, 1468, 1468, 1468, - -28073, -28073, 24313, 24313, -10532, -10532, 8800, 8800, - 18426, 18426, 8859, 8859, 26675, 26675, -16163, -16163, - -681, -681, 1017, 1017, 732, 732, 608, 608, - -1542, -1542, 411, 411, -205, -205, -1571, -1571, - 19883, -28250, -15887, -8898, -28309, 9075, -30199, 18249, - 13426, 14017, -29156, -12757, 16832, 4311, -24155, -17915, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -31183, 25435, -7382, 24391, -20927, 10946, 24214, 16989, - 10335, -7934, -22502, 10906, 31636, 28644, 23998, -17422, - 817, 603, 1322, -1465, -1215, 1218, -874, -1187, - -1185, -1278, -1510, -870, -108, 996, 958, 1522, - 20297, 2146, 15355, -32384, -6280, -14903, -11044, 14469, - -21498, -20198, 23210, -17442, -23860, -20257, 7756, 23132, - 1097, 610, -1285, 384, -136, -1335, 220, -1659, - -1530, 794, -854, 478, -308, 991, -1460, 1628, - -#define _16XSHIFT 624 - SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, - SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT, SHIFT -}}; diff --git 
a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.h deleted file mode 100644 index f95899cd8e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/consts.h +++ /dev/null @@ -1,43 +0,0 @@ -#ifndef CONSTS_H -#define CONSTS_H - -#include "params.h" - -#define _16XQ 0 -#define _16XQINV 16 -#define _16XV 32 -#define _16XFLO 48 -#define _16XFHI 64 -#define _16XMONTSQLO 80 -#define _16XMONTSQHI 96 -#define _16XMASK 112 -#define _REVIDXB 128 -#define _REVIDXD 144 -#define _ZETAS_EXP 160 -#define _16XSHIFT 624 - -/* The C ABI on MacOS exports all symbols with a leading - * underscore. This means that any symbols we refer to from - * C files (functions) can't be found, and all symbols we - * refer to from ASM also can't be found. - * - * This define helps us get around this - */ -#ifdef __ASSEMBLER__ -#if defined(__WIN32__) || defined(__APPLE__) -#define decorate(s) _##s -#define cdecl2(s) decorate(s) -#define cdecl(s) cdecl2(KYBER_NAMESPACE(##s)) -#else -#define cdecl(s) KYBER_NAMESPACE(##s) -#endif -#endif - -#ifndef __ASSEMBLER__ -#include "align.h" -typedef ALIGNED_INT16(640) qdata_t; -#define qdata KYBER_NAMESPACE(qdata) -extern const qdata_t qdata; -#endif - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/indcpa.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/indcpa.c deleted file mode 100644 index c4b2b3a89f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/indcpa.c +++ /dev/null @@ -1,568 +0,0 @@ -#include -#include -#include -#include -#include "align.h" -#include "params.h" -#include "indcpa.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" -#include "cbd.h" -#include "rejsample.h" -#include "symmetric.h" -#include "randombytes.h" - -/************************************************* -* Name: pack_pk -* -* Description: Serialize the public key as concatenation of the -* serialized 
vector of polynomials pk and the -* public seed used to generate the matrix A. -* The polynomial coefficients in pk are assumed to -* lie in the invertal [0,q], i.e. pk must be reduced -* by polyvec_reduce(). -* -* Arguments: uint8_t *r: pointer to the output serialized public key -* polyvec *pk: pointer to the input public-key polyvec -* const uint8_t *seed: pointer to the input public seed -**************************************************/ -static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], - polyvec *pk, - const uint8_t seed[KYBER_SYMBYTES]) -{ - polyvec_tobytes(r, pk); - memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES); -} - -/************************************************* -* Name: unpack_pk -* -* Description: De-serialize public key from a byte array; -* approximate inverse of pack_pk -* -* Arguments: - polyvec *pk: pointer to output public-key polynomial vector -* - uint8_t *seed: pointer to output seed to generate matrix A -* - const uint8_t *packedpk: pointer to input serialized public key -**************************************************/ -static void unpack_pk(polyvec *pk, - uint8_t seed[KYBER_SYMBYTES], - const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) -{ - polyvec_frombytes(pk, packedpk); - memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES); -} - -/************************************************* -* Name: pack_sk -* -* Description: Serialize the secret key. -* The polynomial coefficients in sk are assumed to -* lie in the invertal [0,q], i.e. sk must be reduced -* by polyvec_reduce(). 
-* -* Arguments: - uint8_t *r: pointer to output serialized secret key -* - polyvec *sk: pointer to input vector of polynomials (secret key) -**************************************************/ -static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) -{ - polyvec_tobytes(r, sk); -} - -/************************************************* -* Name: unpack_sk -* -* Description: De-serialize the secret key; inverse of pack_sk -* -* Arguments: - polyvec *sk: pointer to output vector of polynomials (secret key) -* - const uint8_t *packedsk: pointer to input serialized secret key -**************************************************/ -static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) -{ - polyvec_frombytes(sk, packedsk); -} - -/************************************************* -* Name: pack_ciphertext -* -* Description: Serialize the ciphertext as concatenation of the -* compressed and serialized vector of polynomials b -* and the compressed and serialized polynomial v. -* The polynomial coefficients in b and v are assumed to -* lie in the invertal [0,q], i.e. b and v must be reduced -* by polyvec_reduce() and poly_reduce(), respectively. 
-* -* Arguments: uint8_t *r: pointer to the output serialized ciphertext -* poly *pk: pointer to the input vector of polynomials b -* poly *v: pointer to the input polynomial v -**************************************************/ -static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v) -{ - polyvec_compress(r, b); - poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v); -} - -/************************************************* -* Name: unpack_ciphertext -* -* Description: De-serialize and decompress ciphertext from a byte array; -* approximate inverse of pack_ciphertext -* -* Arguments: - polyvec *b: pointer to the output vector of polynomials b -* - poly *v: pointer to the output polynomial v -* - const uint8_t *c: pointer to the input serialized ciphertext -**************************************************/ -static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES]) -{ - polyvec_decompress(b, c); - poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES); -} - -/************************************************* -* Name: rej_uniform -* -* Description: Run rejection sampling on uniform random bytes to generate -* uniform random integers mod q -* -* Arguments: - int16_t *r: pointer to output array -* - unsigned int len: requested number of 16-bit integers (uniform mod q) -* - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes) -* - unsigned int buflen: length of input buffer in bytes -* -* Returns number of sampled 16-bit integers (at most len) -**************************************************/ -static unsigned int rej_uniform(int16_t *r, - unsigned int len, - const uint8_t *buf, - unsigned int buflen) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - - ctr = pos = 0; - while(ctr < len && pos <= buflen - 3) { // buflen is always at least 3 - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF; - pos += 3; - - 
if(val0 < KYBER_Q) - r[ctr++] = val0; - if(ctr < len && val1 < KYBER_Q) - r[ctr++] = val1; - } - - return ctr; -} - -#define gen_a(A,B) gen_matrix(A,B,0) -#define gen_at(A,B) gen_matrix(A,B,1) - -/************************************************* -* Name: gen_matrix -* -* Description: Deterministically generate matrix A (or the transpose of A) -* from a seed. Entries of the matrix are polynomials that look -* uniformly random. Performs rejection sampling on output of -* a XOF -* -* Arguments: - polyvec *a: pointer to ouptput matrix A -* - const uint8_t *seed: pointer to input seed -* - int transposed: boolean deciding whether A or A^T is generated -**************************************************/ -#if KYBER_K == 2 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 0; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 1; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 1; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 0; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 1; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 1; - } - - shake128x4_inc_init(&state); - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[1].vec[0].coeffs, 
buf[2].coeffs); - ctr3 = rej_uniform_avx(a[1].vec[1].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[1].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[1].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[0].vec[0]); - poly_nttunpack(&a[0].vec[1]); - poly_nttunpack(&a[1].vec[0]); - poly_nttunpack(&a[1].vec[1]); - shake128x4_inc_ctx_release(&state); -} -#elif KYBER_K == 3 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - shake128incctx state1x; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 0; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 0; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 0; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 0; - buf[3].coeffs[33] = 1; - } - - shake128x4_inc_init(&state); - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = 
rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[0].vec[2].coeffs, buf[2].coeffs); - ctr3 = rej_uniform_avx(a[1].vec[0].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[0].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[1].vec[0].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[0].vec[0]); - poly_nttunpack(&a[0].vec[1]); - poly_nttunpack(&a[0].vec[2]); - poly_nttunpack(&a[1].vec[0]); - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = 1; - buf[0].coeffs[33] = 1; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = 2; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = 0; - buf[3].coeffs[32] = 2; - buf[3].coeffs[33] = 1; - } - else { - buf[0].coeffs[32] = 1; - buf[0].coeffs[33] = 1; - buf[1].coeffs[32] = 2; - buf[1].coeffs[33] = 1; - buf[2].coeffs[32] = 0; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = 1; - buf[3].coeffs[33] = 2; - } - - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[1].vec[1].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[1].vec[2].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[2].vec[0].coeffs, buf[2].coeffs); - ctr3 = 
rej_uniform_avx(a[2].vec[1].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[1].vec[1].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[1].vec[2].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[2].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[2].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - shake128x4_inc_ctx_release(&state); - - poly_nttunpack(&a[1].vec[1]); - poly_nttunpack(&a[1].vec[2]); - poly_nttunpack(&a[2].vec[0]); - poly_nttunpack(&a[2].vec[1]); - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - buf[0].coeffs[32] = 2; - buf[0].coeffs[33] = 2; - - shake128_inc_init(&state1x); - shake128_absorb_once(&state1x, buf[0].coeffs, 34); - shake128_squeezeblocks(buf[0].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state1x); - ctr0 = rej_uniform_avx(a[2].vec[2].coeffs, buf[0].coeffs); - while(ctr0 < KYBER_N) { - shake128_squeezeblocks(buf[0].coeffs, 1, &state1x); - ctr0 += rej_uniform(a[2].vec[2].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - } - shake128_inc_ctx_release(&state1x); - - poly_nttunpack(&a[2].vec[2]); -} -#elif KYBER_K == 4 -void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed) -{ - unsigned int i, ctr0, ctr1, ctr2, ctr3; - ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4]; - __m256i f; - shake128x4incctx state; - shake128x4_inc_init(&state); - - for(i=0;i<4;i++) { - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - if(transposed) { - buf[0].coeffs[32] = i; - buf[0].coeffs[33] = 0; - buf[1].coeffs[32] = i; - buf[1].coeffs[33] = 
1; - buf[2].coeffs[32] = i; - buf[2].coeffs[33] = 2; - buf[3].coeffs[32] = i; - buf[3].coeffs[33] = 3; - } - else { - buf[0].coeffs[32] = 0; - buf[0].coeffs[33] = i; - buf[1].coeffs[32] = 1; - buf[1].coeffs[33] = i; - buf[2].coeffs[32] = 2; - buf[2].coeffs[33] = i; - buf[3].coeffs[32] = 3; - buf[3].coeffs[33] = i; - } - - shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34); - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state); - - ctr0 = rej_uniform_avx(a[i].vec[0].coeffs, buf[0].coeffs); - ctr1 = rej_uniform_avx(a[i].vec[1].coeffs, buf[1].coeffs); - ctr2 = rej_uniform_avx(a[i].vec[2].coeffs, buf[2].coeffs); - ctr3 = rej_uniform_avx(a[i].vec[3].coeffs, buf[3].coeffs); - - while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) { - shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state); - - ctr0 += rej_uniform(a[i].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE); - ctr1 += rej_uniform(a[i].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE); - ctr2 += rej_uniform(a[i].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE); - ctr3 += rej_uniform(a[i].vec[3].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE); - } - - poly_nttunpack(&a[i].vec[0]); - poly_nttunpack(&a[i].vec[1]); - poly_nttunpack(&a[i].vec[2]); - poly_nttunpack(&a[i].vec[3]); - } - shake128x4_inc_ctx_release(&state); -} -#endif - -/************************************************* -* Name: indcpa_keypair_derand -* -* Description: Generates public and private key for the CPA-secure -* public-key encryption scheme underlying Kyber -* -* Arguments: - uint8_t *pk: pointer to output public key -* (of length KYBER_INDCPA_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (of length KYBER_INDCPA_SECRETKEYBYTES bytes) -* - const uint8_t *coins: pointer to input 
randomness -* (of length KYBER_SYMBYTES bytes) -**************************************************/ -void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]) -{ - unsigned int i; - uint8_t buf[2*KYBER_SYMBYTES]; - const uint8_t *publicseed = buf; - const uint8_t *noiseseed = buf + KYBER_SYMBYTES; - polyvec a[KYBER_K], e, pkpv, skpv; - - memcpy(buf, coins, KYBER_SYMBYTES); - buf[KYBER_SYMBYTES] = KYBER_K; - hash_g(buf, buf, KYBER_SYMBYTES+1); - - gen_a(a, publicseed); - -#if KYBER_K == 2 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, e.vec+0, e.vec+1, noiseseed, 0, 1, 2, 3); -#elif KYBER_K == 3 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, e.vec+0, noiseseed, 0, 1, 2, 3); - poly_getnoise_eta1_4x(e.vec+1, e.vec+2, pkpv.vec+0, pkpv.vec+1, noiseseed, 4, 5, 6, 7); -#elif KYBER_K == 4 - poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, skpv.vec+3, noiseseed, 0, 1, 2, 3); - poly_getnoise_eta1_4x(e.vec+0, e.vec+1, e.vec+2, e.vec+3, noiseseed, 4, 5, 6, 7); -#endif - - polyvec_ntt(&skpv); - polyvec_reduce(&skpv); - polyvec_ntt(&e); - - // matrix-vector multiplication - for(i=0;i -#include "params.h" -#include "polyvec.h" - -#define gen_matrix KYBER_NAMESPACE(gen_matrix) -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed); - -#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand) -void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_enc KYBER_NAMESPACE(indcpa_enc) -void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_dec KYBER_NAMESPACE(indcpa_dec) -void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t 
sk[KYBER_INDCPA_SECRETKEYBYTES]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/invntt.S b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/invntt.S deleted file mode 100644 index 76d4189996..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/invntt.S +++ /dev/null @@ -1,193 +0,0 @@ -#include "consts.h" -.include "shuffle.inc" -.include "fq.inc" - -.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3 -vpsubw %ymm\rl0,%ymm\rh0,%ymm12 -vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 -vpsubw %ymm\rl1,%ymm\rh1,%ymm13 - -vpmullw %ymm\zl0,%ymm12,%ymm\rh0 -vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 -vpsubw %ymm\rl2,%ymm\rh2,%ymm14 - -vpmullw %ymm\zl0,%ymm13,%ymm\rh1 -vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 -vpsubw %ymm\rl3,%ymm\rh3,%ymm15 - -vpmullw %ymm\zl1,%ymm14,%ymm\rh2 -vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 -vpmullw %ymm\zl1,%ymm15,%ymm\rh3 - -vpmulhw %ymm\zh0,%ymm12,%ymm12 -vpmulhw %ymm\zh0,%ymm13,%ymm13 - -vpmulhw %ymm\zh1,%ymm14,%ymm14 -vpmulhw %ymm\zh1,%ymm15,%ymm15 - -vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 - -vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 - -vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 -vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 - -# - -# - -vpsubw %ymm\rh0,%ymm12,%ymm\rh0 - -vpsubw %ymm\rh1,%ymm13,%ymm\rh1 - -vpsubw %ymm\rh2,%ymm14,%ymm\rh2 -vpsubw %ymm\rh3,%ymm15,%ymm\rh3 -.endm - -.macro intt_levels0t5 off -/* level 0 */ -vmovdqa _16XFLO*2(%rsi),%ymm2 -vmovdqa _16XFHI*2(%rsi),%ymm3 - -vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 - -fqmulprecomp 2,3,4 -fqmulprecomp 2,3,6 -fqmulprecomp 2,3,5 -fqmulprecomp 2,3,7 - -vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 -vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 -vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 -vmovdqa (128*\off+112)*2(%rdi),%ymm11 - -fqmulprecomp 2,3,8 -fqmulprecomp 2,3,10 -fqmulprecomp 2,3,9 -fqmulprecomp 2,3,11 - -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15 -vpermq 
$0x4E,(_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+192)*2(%rsi),%ymm3 -vmovdqa _REVIDXB*2(%rsi),%ymm12 -vpshufb %ymm12,%ymm15,%ymm15 -vpshufb %ymm12,%ymm1,%ymm1 -vpshufb %ymm12,%ymm2,%ymm2 -vpshufb %ymm12,%ymm3,%ymm3 - -butterfly 4,5,8,9,6,7,10,11,15,1,2,3 - -/* level 1 */ -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3 -vmovdqa _REVIDXB*2(%rsi),%ymm1 -vpshufb %ymm1,%ymm2,%ymm2 -vpshufb %ymm1,%ymm3,%ymm3 - -butterfly 4,5,6,7,8,9,10,11,2,2,3,3 - -shuffle1 4,5,3,5 -shuffle1 6,7,4,7 -shuffle1 8,9,6,9 -shuffle1 10,11,8,11 - -/* level 2 */ -vmovdqa _REVIDXD*2(%rsi),%ymm12 -vpermd (_ZETAS_EXP+(1-\off)*224+112)*2(%rsi),%ymm12,%ymm2 -vpermd (_ZETAS_EXP+(1-\off)*224+128)*2(%rsi),%ymm12,%ymm10 - -butterfly 3,4,6,8,5,7,9,11,2,2,10,10 - -vmovdqa _16XV*2(%rsi),%ymm1 -red16 3 - -shuffle2 3,4,10,4 -shuffle2 6,8,3,8 -shuffle2 5,7,6,7 -shuffle2 9,11,5,11 - -/* level 3 */ -vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+80)*2(%rsi),%ymm2 -vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+96)*2(%rsi),%ymm9 - -butterfly 10,3,6,5,4,8,7,11,2,2,9,9 - -shuffle4 10,3,9,3 -shuffle4 6,5,10,5 -shuffle4 4,8,6,8 -shuffle4 7,11,4,11 - -/* level 4 */ -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+48)*2(%rsi),%ymm2 -vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+64)*2(%rsi),%ymm7 - -butterfly 9,10,6,4,3,5,8,11,2,2,7,7 - -red16 9 - -shuffle8 9,10,7,10 -shuffle8 6,4,9,4 -shuffle8 3,5,6,5 -shuffle8 8,11,3,11 - -/* level 5 */ -vmovdqa (_ZETAS_EXP+(1-\off)*224+16)*2(%rsi),%ymm2 -vmovdqa (_ZETAS_EXP+(1-\off)*224+32)*2(%rsi),%ymm8 - -butterfly 7,9,6,3,10,4,5,11,2,2,8,8 - -vmovdqa %ymm7,(128*\off+ 0)*2(%rdi) -vmovdqa %ymm9,(128*\off+ 16)*2(%rdi) -vmovdqa %ymm6,(128*\off+ 32)*2(%rdi) -vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) -vmovdqa %ymm10,(128*\off+ 64)*2(%rdi) -vmovdqa %ymm4,(128*\off+ 80)*2(%rdi) -vmovdqa %ymm5,(128*\off+ 96)*2(%rdi) -vmovdqa %ymm11,(128*\off+112)*2(%rdi) -.endm - 
-.macro intt_level6 off -/* level 6 */ -vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (64*\off+128)*2(%rdi),%ymm8 -vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (64*\off+144)*2(%rdi),%ymm9 -vpbroadcastq (_ZETAS_EXP+0)*2(%rsi),%ymm2 - -vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (64*\off+160)*2(%rdi),%ymm10 -vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 -vmovdqa (64*\off+176)*2(%rdi),%ymm11 -vpbroadcastq (_ZETAS_EXP+4)*2(%rsi),%ymm3 - -butterfly 4,5,6,7,8,9,10,11 - -.if \off == 0 -red16 4 -.endif - -vmovdqa %ymm4,(64*\off+ 0)*2(%rdi) -vmovdqa %ymm5,(64*\off+ 16)*2(%rdi) -vmovdqa %ymm6,(64*\off+ 32)*2(%rdi) -vmovdqa %ymm7,(64*\off+ 48)*2(%rdi) -vmovdqa %ymm8,(64*\off+128)*2(%rdi) -vmovdqa %ymm9,(64*\off+144)*2(%rdi) -vmovdqa %ymm10,(64*\off+160)*2(%rdi) -vmovdqa %ymm11,(64*\off+176)*2(%rdi) -.endm - -.text -.global cdecl(invntt_avx) -cdecl(invntt_avx): -vmovdqa _16XQ*2(%rsi),%ymm0 - -intt_levels0t5 0 -intt_levels0t5 1 - -intt_level6 0 -intt_level6 1 -ret diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.c deleted file mode 100644 index 63abc1029c..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.c +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "kem.h" -#include "indcpa.h" -#include "verify.h" -#include "symmetric.h" -#include "randombytes.h" -/************************************************* -* Name: crypto_kem_keypair_derand -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* - uint8_t *coins: pointer to input randomness -* (an already allocated array filled with 2*KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) 
-**************************************************/ -int crypto_kem_keypair_derand(uint8_t *pk, - uint8_t *sk, - const uint8_t *coins) -{ - indcpa_keypair_derand(pk, sk, coins); - memcpy(sk+KYBER_INDCPA_SECRETKEYBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_h(sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - /* Value z for pseudo-random output on reject */ - memcpy(sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES, coins+KYBER_SYMBYTES, KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_keypair -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair(uint8_t *pk, - uint8_t *sk) -{ - uint8_t coins[2*KYBER_SYMBYTES]; - randombytes(coins, 2*KYBER_SYMBYTES); - crypto_kem_keypair_derand(pk, sk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc_derand -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - const uint8_t *coins: pointer to input randomness -* (an already allocated array filled with KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc_derand(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk, - const uint8_t *coins) -{ - uint8_t 
buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - - memcpy(buf, coins, KYBER_SYMBYTES); - - /* Multitarget countermeasure for coins + contributory KEM */ - hash_h(buf+KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(ct, buf, pk, kr+KYBER_SYMBYTES); - - memcpy(ss,kr,KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk) -{ - uint8_t coins[KYBER_SYMBYTES]; - randombytes(coins, KYBER_SYMBYTES); - crypto_kem_enc_derand(ct, ss, pk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_dec -* -* Description: Generates shared secret for given -* cipher text and private key -* -* Arguments: - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *ct: pointer to input cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - const uint8_t *sk: pointer to input private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0. -* -* On failure, ss will contain a pseudo-random value. 
-**************************************************/ -int crypto_kem_dec(uint8_t *ss, - const uint8_t *ct, - const uint8_t *sk) -{ - int fail; - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - uint8_t cmp[KYBER_CIPHERTEXTBYTES+KYBER_SYMBYTES]; - const uint8_t *pk = sk+KYBER_INDCPA_SECRETKEYBYTES; - - indcpa_dec(buf, ct, sk); - - /* Multitarget countermeasure for coins + contributory KEM */ - memcpy(buf+KYBER_SYMBYTES, sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, KYBER_SYMBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(cmp, buf, pk, kr+KYBER_SYMBYTES); - - fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES); - - /* Compute rejection key */ - rkprf(ss,sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES,ct); - - /* Copy true key to return buffer if fail is false */ - cmov(ss,kr,KYBER_SYMBYTES,!fail); - - return 0; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.h deleted file mode 100644 index 234f11966b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/kem.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef KEM_H -#define KEM_H - -#include -#include "params.h" - -#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES -#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES -#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES -#define CRYPTO_BYTES KYBER_SSBYTES - -#if (KYBER_K == 2) -#define CRYPTO_ALGNAME "Kyber512" -#elif (KYBER_K == 3) -#define CRYPTO_ALGNAME "Kyber768" -#elif (KYBER_K == 4) -#define CRYPTO_ALGNAME "Kyber1024" -#endif - -#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand) -int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); - -#define crypto_kem_keypair KYBER_NAMESPACE(keypair) -int crypto_kem_keypair(uint8_t *pk, uint8_t *sk); - -#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand) -int crypto_kem_enc_derand(uint8_t *ct, 
uint8_t *ss, const uint8_t *pk, const uint8_t *coins); - -#define crypto_kem_enc KYBER_NAMESPACE(enc) -int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); - -#define crypto_kem_dec KYBER_NAMESPACE(dec) -int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.S b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.S deleted file mode 100644 index 0ce7b41297..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.S +++ /dev/null @@ -1,189 +0,0 @@ -#include "consts.h" -.include "shuffle.inc" - -.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2 -vpmullw %ymm\zl0,%ymm\rh0,%ymm12 -vpmullw %ymm\zl0,%ymm\rh1,%ymm13 - -vpmullw %ymm\zl1,%ymm\rh2,%ymm14 -vpmullw %ymm\zl1,%ymm\rh3,%ymm15 - -vpmulhw %ymm\zh0,%ymm\rh0,%ymm\rh0 -vpmulhw %ymm\zh0,%ymm\rh1,%ymm\rh1 - -vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2 -vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3 -.endm - -.macro reduce -vpmulhw %ymm0,%ymm12,%ymm12 -vpmulhw %ymm0,%ymm13,%ymm13 - -vpmulhw %ymm0,%ymm14,%ymm14 -vpmulhw %ymm0,%ymm15,%ymm15 -.endm - -.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3 -vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln -vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 -vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 - -vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 -vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 -vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 - -vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 -vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 - -vpsubw %ymm12,%ymm\rln,%ymm\rln -vpaddw %ymm12,%ymm\rh0,%ymm\rh0 -vpsubw %ymm13,%ymm\rl0,%ymm\rl0 - -vpaddw %ymm13,%ymm\rh1,%ymm\rh1 -vpsubw %ymm14,%ymm\rl1,%ymm\rl1 -vpaddw %ymm14,%ymm\rh2,%ymm\rh2 - -vpsubw %ymm15,%ymm\rl2,%ymm\rl2 -vpaddw %ymm15,%ymm\rh3,%ymm\rh3 -.endm - -.macro level0 off -vpbroadcastq (_ZETAS_EXP+0)*2(%rsi),%ymm15 -vmovdqa (64*\off+128)*2(%rdi),%ymm8 -vmovdqa (64*\off+144)*2(%rdi),%ymm9 -vmovdqa (64*\off+160)*2(%rdi),%ymm10 -vmovdqa (64*\off+176)*2(%rdi),%ymm11 -vpbroadcastq 
(_ZETAS_EXP+4)*2(%rsi),%ymm2 - -mul 8,9,10,11 - -vmovdqa (64*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (64*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (64*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (64*\off+ 48)*2(%rdi),%ymm7 - -reduce -update 3,4,5,6,7,8,9,10,11 - -vmovdqa %ymm3,(64*\off+ 0)*2(%rdi) -vmovdqa %ymm4,(64*\off+ 16)*2(%rdi) -vmovdqa %ymm5,(64*\off+ 32)*2(%rdi) -vmovdqa %ymm6,(64*\off+ 48)*2(%rdi) -vmovdqa %ymm8,(64*\off+128)*2(%rdi) -vmovdqa %ymm9,(64*\off+144)*2(%rdi) -vmovdqa %ymm10,(64*\off+160)*2(%rdi) -vmovdqa %ymm11,(64*\off+176)*2(%rdi) -.endm - -.macro levels1t6 off -/* level 1 */ -vmovdqa (_ZETAS_EXP+224*\off+16)*2(%rsi),%ymm15 -vmovdqa (128*\off+ 64)*2(%rdi),%ymm8 -vmovdqa (128*\off+ 80)*2(%rdi),%ymm9 -vmovdqa (128*\off+ 96)*2(%rdi),%ymm10 -vmovdqa (128*\off+112)*2(%rdi),%ymm11 -vmovdqa (_ZETAS_EXP+224*\off+32)*2(%rsi),%ymm2 - -mul 8,9,10,11 - -vmovdqa (128*\off+ 0)*2(%rdi),%ymm4 -vmovdqa (128*\off+ 16)*2(%rdi),%ymm5 -vmovdqa (128*\off+ 32)*2(%rdi),%ymm6 -vmovdqa (128*\off+ 48)*2(%rdi),%ymm7 - -reduce -update 3,4,5,6,7,8,9,10,11 - -/* level 2 */ -shuffle8 5,10,7,10 -shuffle8 6,11,5,11 - -vmovdqa (_ZETAS_EXP+224*\off+48)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+64)*2(%rsi),%ymm2 - -mul 7,10,5,11 - -shuffle8 3,8,6,8 -shuffle8 4,9,3,9 - -reduce -update 4,6,8,3,9,7,10,5,11 - -/* level 3 */ -shuffle4 8,5,9,5 -shuffle4 3,11,8,11 - -vmovdqa (_ZETAS_EXP+224*\off+80)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+96)*2(%rsi),%ymm2 - -mul 9,5,8,11 - -shuffle4 4,7,3,7 -shuffle4 6,10,4,10 - -reduce -update 6,3,7,4,10,9,5,8,11 - -/* level 4 */ -shuffle2 7,8,10,8 -shuffle2 4,11,7,11 - -vmovdqa (_ZETAS_EXP+224*\off+112)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+128)*2(%rsi),%ymm2 - -mul 10,8,7,11 - -shuffle2 6,9,4,9 -shuffle2 3,5,6,5 - -reduce -update 3,4,9,6,5,10,8,7,11 - -/* level 5 */ -shuffle1 9,7,5,7 -shuffle1 6,11,9,11 - -vmovdqa (_ZETAS_EXP+224*\off+144)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+160)*2(%rsi),%ymm2 - -mul 5,7,9,11 - -shuffle1 3,10,6,10 -shuffle1 4,8,3,8 - 
-reduce -update 4,6,10,3,8,5,7,9,11 - -/* level 6 */ -vmovdqa (_ZETAS_EXP+224*\off+176)*2(%rsi),%ymm14 -vmovdqa (_ZETAS_EXP+224*\off+208)*2(%rsi),%ymm15 -vmovdqa (_ZETAS_EXP+224*\off+192)*2(%rsi),%ymm8 -vmovdqa (_ZETAS_EXP+224*\off+224)*2(%rsi),%ymm2 - -mul 10,3,9,11,14,15,8,2 - -reduce -update 8,4,6,5,7,10,3,9,11 - -vmovdqa %ymm8,(128*\off+ 0)*2(%rdi) -vmovdqa %ymm4,(128*\off+ 16)*2(%rdi) -vmovdqa %ymm10,(128*\off+ 32)*2(%rdi) -vmovdqa %ymm3,(128*\off+ 48)*2(%rdi) -vmovdqa %ymm6,(128*\off+ 64)*2(%rdi) -vmovdqa %ymm5,(128*\off+ 80)*2(%rdi) -vmovdqa %ymm9,(128*\off+ 96)*2(%rdi) -vmovdqa %ymm11,(128*\off+112)*2(%rdi) -.endm - -.text -.global cdecl(ntt_avx) -cdecl(ntt_avx): -vmovdqa _16XQ*2(%rsi),%ymm0 - -level0 0 -level0 1 - -levels1t6 0 -levels1t6 1 - -ret diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.h deleted file mode 100644 index a4f48e343b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/ntt.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifndef NTT_H -#define NTT_H - -#include -#include - -#define ntt_avx KYBER_NAMESPACE(ntt_avx) -void ntt_avx(__m256i *r, const __m256i *qdata); -#define invntt_avx KYBER_NAMESPACE(invntt_avx) -void invntt_avx(__m256i *r, const __m256i *qdata); - -#define nttpack_avx KYBER_NAMESPACE(nttpack_avx) -void nttpack_avx(__m256i *r, const __m256i *qdata); -#define nttunpack_avx KYBER_NAMESPACE(nttunpack_avx) -void nttunpack_avx(__m256i *r, const __m256i *qdata); - -#define basemul_avx KYBER_NAMESPACE(basemul_avx) -void basemul_avx(__m256i *r, - const __m256i *a, - const __m256i *b, - const __m256i *qdata); - -#define ntttobytes_avx KYBER_NAMESPACE(ntttobytes_avx) -void ntttobytes_avx(uint8_t *r, const __m256i *a, const __m256i *qdata); -#define nttfrombytes_avx KYBER_NAMESPACE(nttfrombytes_avx) -void nttfrombytes_avx(__m256i *r, const uint8_t *a, const __m256i *qdata); - -#endif diff --git 
a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/params.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/params.h deleted file mode 100644 index ecfabce4a5..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/params.h +++ /dev/null @@ -1,68 +0,0 @@ -#ifndef PARAMS_H -#define PARAMS_H - -#ifndef KYBER_K -#define KYBER_K 3 /* Change this for different security strengths */ -#endif - -//#define KYBER_90S /* Uncomment this if you want the 90S variant */ - -/* Don't change parameters below this line */ -#if (KYBER_K == 2) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber512_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_512_avx2_##s -#endif -#elif (KYBER_K == 3) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber768_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_768_avx2_##s -#endif -#elif (KYBER_K == 4) -#ifdef KYBER_90S -#define KYBER_NAMESPACE(s) pqcrystals_kyber1024_90s_avx2_##s -#else -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_1024_avx2_##s -#endif -#else -#error "KYBER_K must be in {2,3,4}" -#endif - -#define KYBER_N 256 -#define KYBER_Q 3329 - -#define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */ -#define KYBER_SSBYTES 32 /* size in bytes of shared key */ - -#define KYBER_POLYBYTES 384 -#define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES) - -#if KYBER_K == 2 -#define KYBER_ETA1 3 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 3 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 4 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 160 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352) -#endif - -#define KYBER_ETA2 2 - -#define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES) -#define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES) -#define KYBER_INDCPA_SECRETKEYBYTES 
(KYBER_POLYVECBYTES) -#define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES) - -#define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES) -/* 32 bytes of additional space to save H(pk) */ -#define KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2*KYBER_SYMBYTES) -#define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES) - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.c deleted file mode 100644 index 681fd6d23e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.c +++ /dev/null @@ -1,519 +0,0 @@ -#include -#include -#include -#include "align.h" -#include "fips202x4.h" -#include "params.h" -#include "poly.h" -#include "ntt.h" -#include "consts.h" -#include "reduce.h" -#include "cbd.h" -#include "symmetric.h" - -/************************************************* -* Name: poly_compress -* -* Description: Compression and subsequent serialization of a polynomial. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). 
-* -* Arguments: - uint8_t *r: pointer to output byte array -* (of length KYBER_POLYCOMPRESSEDBYTES) -* - const poly *a: pointer to input polynomial -**************************************************/ -#if (KYBER_POLYCOMPRESSEDBYTES == 128) -void poly_compress(uint8_t r[128], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2, f3; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i shift1 = _mm256_set1_epi16(1 << 9); - const __m256i mask = _mm256_set1_epi16(15); - const __m256i shift2 = _mm256_set1_epi16((16 << 8) + 1); - const __m256i permdidx = _mm256_set_epi32(7,3,6,2,5,1,4,0); - - for(i=0;ivec[4*i+0]); - f1 = _mm256_load_si256(&a->vec[4*i+1]); - f2 = _mm256_load_si256(&a->vec[4*i+2]); - f3 = _mm256_load_si256(&a->vec[4*i+3]); - f0 = _mm256_mulhi_epi16(f0,v); - f1 = _mm256_mulhi_epi16(f1,v); - f2 = _mm256_mulhi_epi16(f2,v); - f3 = _mm256_mulhi_epi16(f3,v); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f1 = _mm256_mulhrs_epi16(f1,shift1); - f2 = _mm256_mulhrs_epi16(f2,shift1); - f3 = _mm256_mulhrs_epi16(f3,shift1); - f0 = _mm256_and_si256(f0,mask); - f1 = _mm256_and_si256(f1,mask); - f2 = _mm256_and_si256(f2,mask); - f3 = _mm256_and_si256(f3,mask); - f0 = _mm256_packus_epi16(f0,f1); - f2 = _mm256_packus_epi16(f2,f3); - f0 = _mm256_maddubs_epi16(f0,shift2); - f2 = _mm256_maddubs_epi16(f2,shift2); - f0 = _mm256_packus_epi16(f0,f2); - f0 = _mm256_permutevar8x32_epi32(f0,permdidx); - _mm256_storeu_si256((__m256i *)&r[32*i],f0); - } -} - -void poly_decompress(poly * restrict r, const uint8_t a[128]) -{ - unsigned int i; - __m128i t; - __m256i f; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(7,7,7,7,6,6,6,6,5,5,5,5,4,4,4,4, - 3,3,3,3,2,2,2,2,1,1,1,1,0,0,0,0); - const __m256i mask = _mm256_set1_epi32(0x00F0000F); - const __m256i shift = _mm256_set1_epi32((128 << 16) + 2048); - - for(i=0;ivec[i],f); - } -} - -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) -void poly_compress(uint8_t 
r[160], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i shift1 = _mm256_set1_epi16(1 << 10); - const __m256i mask = _mm256_set1_epi16(31); - const __m256i shift2 = _mm256_set1_epi16((32 << 8) + 1); - const __m256i shift3 = _mm256_set1_epi32((1024 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(12); - const __m256i shufbidx = _mm256_set_epi8( 8,-1,-1,-1,-1,-1, 4, 3, 2, 1, 0,-1,12,11,10, 9, - -1,12,11,10, 9, 8,-1,-1,-1,-1,-1 ,4, 3, 2, 1, 0); - - for(i=0;ivec[2*i+0]); - f1 = _mm256_load_si256(&a->vec[2*i+1]); - f0 = _mm256_mulhi_epi16(f0,v); - f1 = _mm256_mulhi_epi16(f1,v); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f1 = _mm256_mulhrs_epi16(f1,shift1); - f0 = _mm256_and_si256(f0,mask); - f1 = _mm256_and_si256(f1,mask); - f0 = _mm256_packus_epi16(f0,f1); - f0 = _mm256_maddubs_epi16(f0,shift2); // a0 a1 a2 a3 b0 b1 b2 b3 a4 a5 a6 a7 b4 b5 b6 b7 - f0 = _mm256_madd_epi16(f0,shift3); // a0 a1 b0 b1 a2 a3 b2 b3 - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f0 = _mm256_srlv_epi64(f0,sllvdidx); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blendv_epi8(t0,t1,_mm256_castsi256_si128(shufbidx)); - _mm_storeu_si128((__m128i *)&r[20*i+ 0],t0); - memcpy(&r[20*i+16],&t1,4); - } -} - -void poly_decompress(poly * restrict r, const uint8_t a[160]) -{ - unsigned int i; - __m128i t; - __m256i f; - int16_t ti; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(9,9,9,8,8,8,8,7,7,6,6,6,6,5,5,5, - 4,4,4,3,3,3,3,2,2,1,1,1,1,0,0,0); - const __m256i mask = _mm256_set_epi16(248,1984,62,496,3968,124,992,31, - 248,1984,62,496,3968,124,992,31); - const __m256i shift = _mm256_set_epi16(128,16,512,64,8,256,32,1024, - 128,16,512,64,8,256,32,1024); - - for(i=0;ivec[i],f); - } -} - -#endif - -/************************************************* -* Name: 
poly_tobytes -* -* Description: Serialization of a polynomial in NTT representation. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). The coefficients are orderd as output by -* poly_ntt(); the serialized output coefficients are in bitreversed -* order. -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYBYTES bytes) -* - poly *a: pointer to input polynomial -**************************************************/ -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) -{ - ntttobytes_avx(r, a->vec, qdata.vec); -} - -/************************************************* -* Name: poly_frombytes -* -* Description: De-serialization of a polynomial; -* inverse of poly_tobytes -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of KYBER_POLYBYTES bytes) -**************************************************/ -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) -{ - nttfrombytes_avx(r->vec, a, qdata.vec); -} - -/************************************************* -* Name: poly_frommsg -* -* Description: Convert 32-byte message to polynomial -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *msg: pointer to input message -**************************************************/ -void poly_frommsg(poly * restrict r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) -{ -#if (KYBER_INDCPA_MSGBYTES != 32) -#error "KYBER_INDCPA_MSGBYTES must be equal to 32!" 
-#endif - __m256i f, g0, g1, g2, g3, h0, h1, h2, h3; - const __m256i shift = _mm256_broadcastsi128_si256(_mm_set_epi32(0,1,2,3)); - const __m256i idx = _mm256_broadcastsi128_si256(_mm_set_epi8(15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0)); - const __m256i hqs = _mm256_set1_epi16((KYBER_Q+1)/2); - -#define FROMMSG64(i) \ - g3 = _mm256_shuffle_epi32(f,0x55*i); \ - g3 = _mm256_sllv_epi32(g3,shift); \ - g3 = _mm256_shuffle_epi8(g3,idx); \ - g0 = _mm256_slli_epi16(g3,12); \ - g1 = _mm256_slli_epi16(g3,8); \ - g2 = _mm256_slli_epi16(g3,4); \ - g0 = _mm256_srai_epi16(g0,15); \ - g1 = _mm256_srai_epi16(g1,15); \ - g2 = _mm256_srai_epi16(g2,15); \ - g3 = _mm256_srai_epi16(g3,15); \ - g0 = _mm256_and_si256(g0,hqs); /* 19 18 17 16 3 2 1 0 */ \ - g1 = _mm256_and_si256(g1,hqs); /* 23 22 21 20 7 6 5 4 */ \ - g2 = _mm256_and_si256(g2,hqs); /* 27 26 25 24 11 10 9 8 */ \ - g3 = _mm256_and_si256(g3,hqs); /* 31 30 29 28 15 14 13 12 */ \ - h0 = _mm256_unpacklo_epi64(g0,g1); \ - h2 = _mm256_unpackhi_epi64(g0,g1); \ - h1 = _mm256_unpacklo_epi64(g2,g3); \ - h3 = _mm256_unpackhi_epi64(g2,g3); \ - g0 = _mm256_permute2x128_si256(h0,h1,0x20); \ - g2 = _mm256_permute2x128_si256(h0,h1,0x31); \ - g1 = _mm256_permute2x128_si256(h2,h3,0x20); \ - g3 = _mm256_permute2x128_si256(h2,h3,0x31); \ - _mm256_store_si256(&r->vec[0+2*i+0],g0); \ - _mm256_store_si256(&r->vec[0+2*i+1],g1); \ - _mm256_store_si256(&r->vec[8+2*i+0],g2); \ - _mm256_store_si256(&r->vec[8+2*i+1],g3) - - f = _mm256_loadu_si256((__m256i *)msg); - FROMMSG64(0); - FROMMSG64(1); - FROMMSG64(2); - FROMMSG64(3); -} - -/************************************************* -* Name: poly_tomsg -* -* Description: Convert polynomial to 32-byte message. -* The coefficients of the input polynomial are assumed to -* lie in the invertal [0,q], i.e. the polynomial must be reduced -* by poly_reduce(). 
-* -* Arguments: - uint8_t *msg: pointer to output message -* - poly *a: pointer to input polynomial -**************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly * restrict a) -{ - unsigned int i; - uint32_t small; - __m256i f0, f1, g0, g1; - const __m256i hq = _mm256_set1_epi16((KYBER_Q - 1)/2); - const __m256i hhq = _mm256_set1_epi16((KYBER_Q - 1)/4); - - for(i=0;ivec[2*i+0]); - f1 = _mm256_load_si256(&a->vec[2*i+1]); - f0 = _mm256_sub_epi16(hq, f0); - f1 = _mm256_sub_epi16(hq, f1); - g0 = _mm256_srai_epi16(f0, 15); - g1 = _mm256_srai_epi16(f1, 15); - f0 = _mm256_xor_si256(f0, g0); - f1 = _mm256_xor_si256(f1, g1); - f0 = _mm256_sub_epi16(f0, hhq); - f1 = _mm256_sub_epi16(f1, hhq); - f0 = _mm256_packs_epi16(f0, f1); - f0 = _mm256_permute4x64_epi64(f0, 0xD8); - small = _mm256_movemask_epi8(f0); - memcpy(&msg[4*i], &small, 4); - } -} - -/************************************************* -* Name: poly_getnoise_eta1 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA1 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - ALIGNED_UINT8(KYBER_ETA1*KYBER_N/4+32) buf; // +32 bytes as required by poly_cbd_eta1 - prf(buf.coeffs, KYBER_ETA1*KYBER_N/4, seed, nonce); - poly_cbd_eta1(r, buf.vec); -} - -/************************************************* -* Name: poly_getnoise_eta2 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: 
pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - ALIGNED_UINT8(KYBER_ETA2*KYBER_N/4) buf; - prf(buf.coeffs, KYBER_ETA2*KYBER_N/4, seed, nonce); - poly_cbd_eta2(r, buf.vec); -} - -#ifndef KYBER_90S -#define NOISE_NBLOCKS ((KYBER_ETA1*KYBER_N/4+SHAKE256_RATE-1)/SHAKE256_RATE) -void poly_getnoise_eta1_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3) -{ - ALIGNED_UINT8(NOISE_NBLOCKS*SHAKE256_RATE) buf[4]; - __m256i f; - shake256x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - buf[0].coeffs[32] = nonce0; - buf[1].coeffs[32] = nonce1; - buf[2].coeffs[32] = nonce2; - buf[3].coeffs[32] = nonce3; - - shake256x4_inc_init(&state); - shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33); - shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state); - shake256x4_inc_ctx_release(&state); - - poly_cbd_eta1(r0, buf[0].vec); - poly_cbd_eta1(r1, buf[1].vec); - poly_cbd_eta1(r2, buf[2].vec); - poly_cbd_eta1(r3, buf[3].vec); -} - -#if KYBER_K == 2 -void poly_getnoise_eta1122_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3) -{ - ALIGNED_UINT8(NOISE_NBLOCKS*SHAKE256_RATE) buf[4]; - __m256i f; - shake256x4incctx state; - - f = _mm256_loadu_si256((__m256i *)seed); - _mm256_store_si256(buf[0].vec, f); - _mm256_store_si256(buf[1].vec, f); - _mm256_store_si256(buf[2].vec, f); - _mm256_store_si256(buf[3].vec, f); - - buf[0].coeffs[32] = nonce0; - 
buf[1].coeffs[32] = nonce1; - buf[2].coeffs[32] = nonce2; - buf[3].coeffs[32] = nonce3; - - shake256x4_inc_init(&state); - shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33); - shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state); - shake256x4_inc_ctx_release(&state); - - poly_cbd_eta1(r0, buf[0].vec); - poly_cbd_eta1(r1, buf[1].vec); - poly_cbd_eta2(r2, buf[2].vec); - poly_cbd_eta2(r3, buf[3].vec); -} -#endif -#endif - -/************************************************* -* Name: poly_ntt -* -* Description: Computes negacyclic number-theoretic transform (NTT) of -* a polynomial in place. -* Input coefficients assumed to be in normal order, -* output coefficients are in special order that is natural -* for the vectorization. Input coefficients are assumed to be -* bounded by q in absolute value, output coefficients are bounded -* by 16118 in absolute value. -* -* Arguments: - poly *r: pointer to in/output polynomial -**************************************************/ -void poly_ntt(poly *r) -{ - ntt_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_invntt_tomont -* -* Description: Computes inverse of negacyclic number-theoretic transform (NTT) -* of a polynomial in place; -* Input coefficients assumed to be in special order from vectorized -* forward ntt, output in normal order. Input coefficients can be -* arbitrary 16-bit integers, output coefficients are bounded by 14870 -* in absolute value. -* -* Arguments: - poly *a: pointer to in/output polynomial -**************************************************/ -void poly_invntt_tomont(poly *r) -{ - invntt_avx(r->vec, qdata.vec); -} - -void poly_nttunpack(poly *r) -{ - nttunpack_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_basemul_montgomery -* -* Description: Multiplication of two polynomials in NTT domain. 
-* One of the input polynomials needs to have coefficients -* bounded by q, the other polynomial can have arbitrary -* coefficients. Output coefficients are bounded by 6656. -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) -{ - basemul_avx(r->vec, a->vec, b->vec, qdata.vec); -} - -/************************************************* -* Name: poly_tomont -* -* Description: Inplace conversion of all coefficients of a polynomial -* from normal domain to Montgomery domain -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_tomont(poly *r) -{ - tomont_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_reduce -* -* Description: Applies Barrett reduction to all coefficients of a polynomial -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_reduce(poly *r) -{ - reduce_avx(r->vec, qdata.vec); -} - -/************************************************* -* Name: poly_add -* -* Description: Add two polynomials. No modular reduction -* is performed. 
-* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_add(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - __m256i f0, f1; - - for(i=0;ivec[i]); - f1 = _mm256_load_si256(&b->vec[i]); - f0 = _mm256_add_epi16(f0, f1); - _mm256_store_si256(&r->vec[i], f0); - } -} - -/************************************************* -* Name: poly_sub -* -* Description: Subtract two polynomials. No modular reduction -* is performed. -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_sub(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - __m256i f0, f1; - - for(i=0;ivec[i]); - f1 = _mm256_load_si256(&b->vec[i]); - f0 = _mm256_sub_epi16(f0, f1); - _mm256_store_si256(&r->vec[i], f0); - } -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.h deleted file mode 100644 index 6a9cf71c70..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/poly.h +++ /dev/null @@ -1,77 +0,0 @@ -#ifndef POLY_H -#define POLY_H - -#include -#include "align.h" -#include "params.h" - -typedef ALIGNED_INT16(KYBER_N) poly; - -#define poly_compress KYBER_NAMESPACE(poly_compress) -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a); -#define poly_decompress KYBER_NAMESPACE(poly_decompress) -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]); - -#define poly_tobytes KYBER_NAMESPACE(poly_tobytes) -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a); -#define poly_frombytes KYBER_NAMESPACE(poly_frombytes) -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]); - -#define 
poly_frommsg KYBER_NAMESPACE(poly_frommsg) -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]); -#define poly_tomsg KYBER_NAMESPACE(poly_tomsg) -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r); - -#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1) -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2) -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#ifndef KYBER_90S -#define poly_getnoise_eta1_4x KYBER_NAMESPACE(poly_getnoise_eta2_4x) -void poly_getnoise_eta1_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3); - -#if KYBER_K == 2 -#define poly_getnoise_eta1122_4x KYBER_NAMESPACE(poly_getnoise_eta1122_4x) -void poly_getnoise_eta1122_4x(poly *r0, - poly *r1, - poly *r2, - poly *r3, - const uint8_t seed[32], - uint8_t nonce0, - uint8_t nonce1, - uint8_t nonce2, - uint8_t nonce3); -#endif -#endif - - -#define poly_ntt KYBER_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); -#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont) -void poly_invntt_tomont(poly *r); -#define poly_nttunpack KYBER_NAMESPACE(poly_nttunpack) -void poly_nttunpack(poly *r); -#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery) -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b); -#define poly_tomont KYBER_NAMESPACE(poly_tomont) -void poly_tomont(poly *r); - -#define poly_reduce KYBER_NAMESPACE(poly_reduce) -void poly_reduce(poly *r); - -#define poly_add KYBER_NAMESPACE(poly_add) -void poly_add(poly *r, const poly *a, const poly *b); -#define poly_sub KYBER_NAMESPACE(poly_sub) -void poly_sub(poly *r, const poly *a, const poly *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.c 
b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.c deleted file mode 100644 index a0174b7b3f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.c +++ /dev/null @@ -1,307 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" -#include "consts.h" - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) -static void poly_compress10(uint8_t r[320], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i v8 = _mm256_slli_epi16(v,3); - const __m256i off = _mm256_set1_epi16(15); - const __m256i shift1 = _mm256_set1_epi16(1 << 12); - const __m256i mask = _mm256_set1_epi16(1023); - const __m256i shift2 = _mm256_set1_epi64x((1024LL << 48) + (1LL << 32) + (1024 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(12); - const __m256i shufbidx = _mm256_set_epi8( 8, 4, 3, 2, 1, 0,-1,-1,-1,-1,-1,-1,12,11,10, 9, - -1,-1,-1,-1,-1,-1,12,11,10, 9, 8, 4, 3, 2, 1, 0); - - for(i=0;ivec[i]); - f1 = _mm256_mullo_epi16(f0,v8); - f2 = _mm256_add_epi16(f0,off); - f0 = _mm256_slli_epi16(f0,3); - f0 = _mm256_mulhi_epi16(f0,v); - f2 = _mm256_sub_epi16(f1,f2); - f1 = _mm256_andnot_si256(f1,f2); - f1 = _mm256_srli_epi16(f1,15); - f0 = _mm256_sub_epi16(f0,f1); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f0 = _mm256_and_si256(f0,mask); - f0 = _mm256_madd_epi16(f0,shift2); - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f0 = _mm256_srli_epi64(f0,12); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blend_epi16(t0,t1,0xE0); - _mm_storeu_si128((__m128i *)&r[20*i+ 0],t0); - memcpy(&r[20*i+16],&t1,4); - } -} - -static void poly_decompress10(poly * restrict r, const uint8_t a[320+12]) -{ - unsigned int i; - __m256i f; - const __m256i q = _mm256_set1_epi32((KYBER_Q << 16) + 4*KYBER_Q); - const __m256i shufbidx = 
_mm256_set_epi8(11,10,10, 9, 9, 8, 8, 7, - 6, 5, 5, 4, 4, 3, 3, 2, - 9, 8, 8, 7, 7, 6, 6, 5, - 4, 3, 3, 2, 2, 1, 1, 0); - const __m256i sllvdidx = _mm256_set1_epi64x(4); - const __m256i mask = _mm256_set1_epi32((32736 << 16) + 8184); - - for(i=0;ivec[i],f); - } -} - -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) -static void poly_compress11(uint8_t r[352+2], const poly * restrict a) -{ - unsigned int i; - __m256i f0, f1, f2; - __m128i t0, t1; - const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]); - const __m256i v8 = _mm256_slli_epi16(v,3); - const __m256i off = _mm256_set1_epi16(36); - const __m256i shift1 = _mm256_set1_epi16(1 << 13); - const __m256i mask = _mm256_set1_epi16(2047); - const __m256i shift2 = _mm256_set1_epi64x((2048LL << 48) + (1LL << 32) + (2048 << 16) + 1); - const __m256i sllvdidx = _mm256_set1_epi64x(10); - const __m256i srlvqidx = _mm256_set_epi64x(30,10,30,10); - const __m256i shufbidx = _mm256_set_epi8( 4, 3, 2, 1, 0, 0,-1,-1,-1,-1,10, 9, 8, 7, 6, 5, - -1,-1,-1,-1,-1,10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); - - for(i=0;ivec[i]); - f1 = _mm256_mullo_epi16(f0,v8); - f2 = _mm256_add_epi16(f0,off); - f0 = _mm256_slli_epi16(f0,3); - f0 = _mm256_mulhi_epi16(f0,v); - f2 = _mm256_sub_epi16(f1,f2); - f1 = _mm256_andnot_si256(f1,f2); - f1 = _mm256_srli_epi16(f1,15); - f0 = _mm256_sub_epi16(f0,f1); - f0 = _mm256_mulhrs_epi16(f0,shift1); - f0 = _mm256_and_si256(f0,mask); - f0 = _mm256_madd_epi16(f0,shift2); - f0 = _mm256_sllv_epi32(f0,sllvdidx); - f1 = _mm256_bsrli_epi128(f0,8); - f0 = _mm256_srlv_epi64(f0,srlvqidx); - f1 = _mm256_slli_epi64(f1,34); - f0 = _mm256_add_epi64(f0,f1); - f0 = _mm256_shuffle_epi8(f0,shufbidx); - t0 = _mm256_castsi256_si128(f0); - t1 = _mm256_extracti128_si256(f0,1); - t0 = _mm_blendv_epi8(t0,t1,_mm256_castsi256_si128(shufbidx)); - _mm_storeu_si128((__m128i *)&r[22*i+ 0],t0); - _mm_storel_epi64((__m128i *)&r[22*i+16],t1); - } -} - -static void poly_decompress11(poly * restrict r, const uint8_t a[352+10]) -{ - 
unsigned int i; - __m256i f; - const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i shufbidx = _mm256_set_epi8(13,12,12,11,10, 9, 9, 8, - 8, 7, 6, 5, 5, 4, 4, 3, - 10, 9, 9, 8, 7, 6, 6, 5, - 5, 4, 3, 2, 2, 1, 1, 0); - const __m256i srlvdidx = _mm256_set_epi32(0,0,1,0,0,0,1,0); - const __m256i srlvqidx = _mm256_set_epi64x(2,0,2,0); - const __m256i shift = _mm256_set_epi16(4,32,1,8,32,1,4,32,4,32,1,8,32,1,4,32); - const __m256i mask = _mm256_set1_epi16(32752); - - for(i=0;ivec[i],f); - } -} - -#endif - -/************************************************* -* Name: polyvec_compress -* -* Description: Compress and serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECCOMPRESSEDBYTES) -* - polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES+2], const polyvec *a) -{ - unsigned int i; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - for(i=0;ivec[i]); -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - for(i=0;ivec[i]); -#endif -} - -/************************************************* -* Name: polyvec_decompress -* -* Description: De-serialize and decompress vector of polynomials; -* approximate inverse of polyvec_compress -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYVECCOMPRESSEDBYTES) -**************************************************/ -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES+12]) -{ - unsigned int i; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - for(i=0;ivec[i],&a[320*i]); -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - for(i=0;ivec[i],&a[352*i]); -#endif -} - -/************************************************* -* Name: polyvec_tobytes -* -* Description: Serialize vector of polynomials -* 
-* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECBYTES) -* - polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_frombytes -* -* Description: De-serialize vector of polynomials; -* inverse of polyvec_tobytes -* -* Arguments: - uint8_t *r: pointer to output byte array -* - const polyvec *a: pointer to input vector of polynomials -* (of length KYBER_POLYVECBYTES) -**************************************************/ -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]) -{ - unsigned int i; - for(i=0;ivec[i], a+i*KYBER_POLYBYTES); -} - -/************************************************* -* Name: polyvec_ntt -* -* Description: Apply forward NTT to all elements of a vector of polynomials -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_ntt(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_invntt_tomont -* -* Description: Apply inverse NTT to all elements of a vector of polynomials -* and multiply by Montgomery factor 2^16 -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_invntt_tomont(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_basemul_acc_montgomery -* -* Description: Multiply elements in a and b in NTT domain, accumulate into r, -* and multiply by 2^-16. 
-* -* Arguments: - poly *r: pointer to output polynomial -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - poly tmp; - - poly_basemul_montgomery(r,&a->vec[0],&b->vec[0]); - for(i=1;ivec[i],&b->vec[i]); - poly_add(r,r,&tmp); - } -} - -/************************************************* -* Name: polyvec_reduce -* -* Description: Applies Barrett reduction to each coefficient -* of each element of a vector of polynomials; -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - polyvec *r: pointer to input/output polynomial -**************************************************/ -void polyvec_reduce(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_add -* -* Description: Add vectors of polynomials -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - for(i=0;ivec[i], &a->vec[i], &b->vec[i]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.h deleted file mode 100644 index 2ce23c31ff..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/polyvec.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef POLYVEC_H -#define POLYVEC_H - -#include -#include "params.h" -#include "poly.h" - -typedef struct{ - poly vec[KYBER_K]; -} polyvec; - -#define polyvec_compress KYBER_NAMESPACE(polyvec_compress) -void polyvec_compress(uint8_t 
r[KYBER_POLYVECCOMPRESSEDBYTES+2], const polyvec *a); -#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress) -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES+12]); - -#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes) -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a); -#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes) -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]); - -#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt) -void polyvec_ntt(polyvec *r); -#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont) -void polyvec_invntt_tomont(polyvec *r); - -#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery) -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b); - -#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce) -void polyvec_reduce(polyvec *r); - -#define polyvec_add KYBER_NAMESPACE(polyvec_add) -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/reduce.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/reduce.h deleted file mode 100644 index 5368185b5f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/reduce.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef REDUCE_H -#define REDUCE_H - -#include "params.h" -#include - -#define reduce_avx KYBER_NAMESPACE(reduce_avx) -void reduce_avx(__m256i *r, const __m256i *qdata); -#define tomont_avx KYBER_NAMESPACE(tomont_avx) -void tomont_avx(__m256i *r, const __m256i *qdata); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.c deleted file mode 100644 index 9060a44cb9..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.c +++ /dev/null @@ -1,398 +0,0 @@ -#include -#include -#include 
-#include "params.h" -#include "consts.h" -#include "rejsample.h" - -//#define BMI - -#ifndef BMI -static const uint8_t idx[256][8] = { - {-1, -1, -1, -1, -1, -1, -1, -1}, - { 0, -1, -1, -1, -1, -1, -1, -1}, - { 2, -1, -1, -1, -1, -1, -1, -1}, - { 0, 2, -1, -1, -1, -1, -1, -1}, - { 4, -1, -1, -1, -1, -1, -1, -1}, - { 0, 4, -1, -1, -1, -1, -1, -1}, - { 2, 4, -1, -1, -1, -1, -1, -1}, - { 0, 2, 4, -1, -1, -1, -1, -1}, - { 6, -1, -1, -1, -1, -1, -1, -1}, - { 0, 6, -1, -1, -1, -1, -1, -1}, - { 2, 6, -1, -1, -1, -1, -1, -1}, - { 0, 2, 6, -1, -1, -1, -1, -1}, - { 4, 6, -1, -1, -1, -1, -1, -1}, - { 0, 4, 6, -1, -1, -1, -1, -1}, - { 2, 4, 6, -1, -1, -1, -1, -1}, - { 0, 2, 4, 6, -1, -1, -1, -1}, - { 8, -1, -1, -1, -1, -1, -1, -1}, - { 0, 8, -1, -1, -1, -1, -1, -1}, - { 2, 8, -1, -1, -1, -1, -1, -1}, - { 0, 2, 8, -1, -1, -1, -1, -1}, - { 4, 8, -1, -1, -1, -1, -1, -1}, - { 0, 4, 8, -1, -1, -1, -1, -1}, - { 2, 4, 8, -1, -1, -1, -1, -1}, - { 0, 2, 4, 8, -1, -1, -1, -1}, - { 6, 8, -1, -1, -1, -1, -1, -1}, - { 0, 6, 8, -1, -1, -1, -1, -1}, - { 2, 6, 8, -1, -1, -1, -1, -1}, - { 0, 2, 6, 8, -1, -1, -1, -1}, - { 4, 6, 8, -1, -1, -1, -1, -1}, - { 0, 4, 6, 8, -1, -1, -1, -1}, - { 2, 4, 6, 8, -1, -1, -1, -1}, - { 0, 2, 4, 6, 8, -1, -1, -1}, - {10, -1, -1, -1, -1, -1, -1, -1}, - { 0, 10, -1, -1, -1, -1, -1, -1}, - { 2, 10, -1, -1, -1, -1, -1, -1}, - { 0, 2, 10, -1, -1, -1, -1, -1}, - { 4, 10, -1, -1, -1, -1, -1, -1}, - { 0, 4, 10, -1, -1, -1, -1, -1}, - { 2, 4, 10, -1, -1, -1, -1, -1}, - { 0, 2, 4, 10, -1, -1, -1, -1}, - { 6, 10, -1, -1, -1, -1, -1, -1}, - { 0, 6, 10, -1, -1, -1, -1, -1}, - { 2, 6, 10, -1, -1, -1, -1, -1}, - { 0, 2, 6, 10, -1, -1, -1, -1}, - { 4, 6, 10, -1, -1, -1, -1, -1}, - { 0, 4, 6, 10, -1, -1, -1, -1}, - { 2, 4, 6, 10, -1, -1, -1, -1}, - { 0, 2, 4, 6, 10, -1, -1, -1}, - { 8, 10, -1, -1, -1, -1, -1, -1}, - { 0, 8, 10, -1, -1, -1, -1, -1}, - { 2, 8, 10, -1, -1, -1, -1, -1}, - { 0, 2, 8, 10, -1, -1, -1, -1}, - { 4, 8, 10, -1, -1, -1, -1, -1}, - { 0, 4, 8, 10, -1, -1, 
-1, -1}, - { 2, 4, 8, 10, -1, -1, -1, -1}, - { 0, 2, 4, 8, 10, -1, -1, -1}, - { 6, 8, 10, -1, -1, -1, -1, -1}, - { 0, 6, 8, 10, -1, -1, -1, -1}, - { 2, 6, 8, 10, -1, -1, -1, -1}, - { 0, 2, 6, 8, 10, -1, -1, -1}, - { 4, 6, 8, 10, -1, -1, -1, -1}, - { 0, 4, 6, 8, 10, -1, -1, -1}, - { 2, 4, 6, 8, 10, -1, -1, -1}, - { 0, 2, 4, 6, 8, 10, -1, -1}, - {12, -1, -1, -1, -1, -1, -1, -1}, - { 0, 12, -1, -1, -1, -1, -1, -1}, - { 2, 12, -1, -1, -1, -1, -1, -1}, - { 0, 2, 12, -1, -1, -1, -1, -1}, - { 4, 12, -1, -1, -1, -1, -1, -1}, - { 0, 4, 12, -1, -1, -1, -1, -1}, - { 2, 4, 12, -1, -1, -1, -1, -1}, - { 0, 2, 4, 12, -1, -1, -1, -1}, - { 6, 12, -1, -1, -1, -1, -1, -1}, - { 0, 6, 12, -1, -1, -1, -1, -1}, - { 2, 6, 12, -1, -1, -1, -1, -1}, - { 0, 2, 6, 12, -1, -1, -1, -1}, - { 4, 6, 12, -1, -1, -1, -1, -1}, - { 0, 4, 6, 12, -1, -1, -1, -1}, - { 2, 4, 6, 12, -1, -1, -1, -1}, - { 0, 2, 4, 6, 12, -1, -1, -1}, - { 8, 12, -1, -1, -1, -1, -1, -1}, - { 0, 8, 12, -1, -1, -1, -1, -1}, - { 2, 8, 12, -1, -1, -1, -1, -1}, - { 0, 2, 8, 12, -1, -1, -1, -1}, - { 4, 8, 12, -1, -1, -1, -1, -1}, - { 0, 4, 8, 12, -1, -1, -1, -1}, - { 2, 4, 8, 12, -1, -1, -1, -1}, - { 0, 2, 4, 8, 12, -1, -1, -1}, - { 6, 8, 12, -1, -1, -1, -1, -1}, - { 0, 6, 8, 12, -1, -1, -1, -1}, - { 2, 6, 8, 12, -1, -1, -1, -1}, - { 0, 2, 6, 8, 12, -1, -1, -1}, - { 4, 6, 8, 12, -1, -1, -1, -1}, - { 0, 4, 6, 8, 12, -1, -1, -1}, - { 2, 4, 6, 8, 12, -1, -1, -1}, - { 0, 2, 4, 6, 8, 12, -1, -1}, - {10, 12, -1, -1, -1, -1, -1, -1}, - { 0, 10, 12, -1, -1, -1, -1, -1}, - { 2, 10, 12, -1, -1, -1, -1, -1}, - { 0, 2, 10, 12, -1, -1, -1, -1}, - { 4, 10, 12, -1, -1, -1, -1, -1}, - { 0, 4, 10, 12, -1, -1, -1, -1}, - { 2, 4, 10, 12, -1, -1, -1, -1}, - { 0, 2, 4, 10, 12, -1, -1, -1}, - { 6, 10, 12, -1, -1, -1, -1, -1}, - { 0, 6, 10, 12, -1, -1, -1, -1}, - { 2, 6, 10, 12, -1, -1, -1, -1}, - { 0, 2, 6, 10, 12, -1, -1, -1}, - { 4, 6, 10, 12, -1, -1, -1, -1}, - { 0, 4, 6, 10, 12, -1, -1, -1}, - { 2, 4, 6, 10, 12, -1, -1, -1}, - { 0, 2, 4, 6, 10, 12, 
-1, -1}, - { 8, 10, 12, -1, -1, -1, -1, -1}, - { 0, 8, 10, 12, -1, -1, -1, -1}, - { 2, 8, 10, 12, -1, -1, -1, -1}, - { 0, 2, 8, 10, 12, -1, -1, -1}, - { 4, 8, 10, 12, -1, -1, -1, -1}, - { 0, 4, 8, 10, 12, -1, -1, -1}, - { 2, 4, 8, 10, 12, -1, -1, -1}, - { 0, 2, 4, 8, 10, 12, -1, -1}, - { 6, 8, 10, 12, -1, -1, -1, -1}, - { 0, 6, 8, 10, 12, -1, -1, -1}, - { 2, 6, 8, 10, 12, -1, -1, -1}, - { 0, 2, 6, 8, 10, 12, -1, -1}, - { 4, 6, 8, 10, 12, -1, -1, -1}, - { 0, 4, 6, 8, 10, 12, -1, -1}, - { 2, 4, 6, 8, 10, 12, -1, -1}, - { 0, 2, 4, 6, 8, 10, 12, -1}, - {14, -1, -1, -1, -1, -1, -1, -1}, - { 0, 14, -1, -1, -1, -1, -1, -1}, - { 2, 14, -1, -1, -1, -1, -1, -1}, - { 0, 2, 14, -1, -1, -1, -1, -1}, - { 4, 14, -1, -1, -1, -1, -1, -1}, - { 0, 4, 14, -1, -1, -1, -1, -1}, - { 2, 4, 14, -1, -1, -1, -1, -1}, - { 0, 2, 4, 14, -1, -1, -1, -1}, - { 6, 14, -1, -1, -1, -1, -1, -1}, - { 0, 6, 14, -1, -1, -1, -1, -1}, - { 2, 6, 14, -1, -1, -1, -1, -1}, - { 0, 2, 6, 14, -1, -1, -1, -1}, - { 4, 6, 14, -1, -1, -1, -1, -1}, - { 0, 4, 6, 14, -1, -1, -1, -1}, - { 2, 4, 6, 14, -1, -1, -1, -1}, - { 0, 2, 4, 6, 14, -1, -1, -1}, - { 8, 14, -1, -1, -1, -1, -1, -1}, - { 0, 8, 14, -1, -1, -1, -1, -1}, - { 2, 8, 14, -1, -1, -1, -1, -1}, - { 0, 2, 8, 14, -1, -1, -1, -1}, - { 4, 8, 14, -1, -1, -1, -1, -1}, - { 0, 4, 8, 14, -1, -1, -1, -1}, - { 2, 4, 8, 14, -1, -1, -1, -1}, - { 0, 2, 4, 8, 14, -1, -1, -1}, - { 6, 8, 14, -1, -1, -1, -1, -1}, - { 0, 6, 8, 14, -1, -1, -1, -1}, - { 2, 6, 8, 14, -1, -1, -1, -1}, - { 0, 2, 6, 8, 14, -1, -1, -1}, - { 4, 6, 8, 14, -1, -1, -1, -1}, - { 0, 4, 6, 8, 14, -1, -1, -1}, - { 2, 4, 6, 8, 14, -1, -1, -1}, - { 0, 2, 4, 6, 8, 14, -1, -1}, - {10, 14, -1, -1, -1, -1, -1, -1}, - { 0, 10, 14, -1, -1, -1, -1, -1}, - { 2, 10, 14, -1, -1, -1, -1, -1}, - { 0, 2, 10, 14, -1, -1, -1, -1}, - { 4, 10, 14, -1, -1, -1, -1, -1}, - { 0, 4, 10, 14, -1, -1, -1, -1}, - { 2, 4, 10, 14, -1, -1, -1, -1}, - { 0, 2, 4, 10, 14, -1, -1, -1}, - { 6, 10, 14, -1, -1, -1, -1, -1}, - { 0, 6, 10, 14, -1, 
-1, -1, -1}, - { 2, 6, 10, 14, -1, -1, -1, -1}, - { 0, 2, 6, 10, 14, -1, -1, -1}, - { 4, 6, 10, 14, -1, -1, -1, -1}, - { 0, 4, 6, 10, 14, -1, -1, -1}, - { 2, 4, 6, 10, 14, -1, -1, -1}, - { 0, 2, 4, 6, 10, 14, -1, -1}, - { 8, 10, 14, -1, -1, -1, -1, -1}, - { 0, 8, 10, 14, -1, -1, -1, -1}, - { 2, 8, 10, 14, -1, -1, -1, -1}, - { 0, 2, 8, 10, 14, -1, -1, -1}, - { 4, 8, 10, 14, -1, -1, -1, -1}, - { 0, 4, 8, 10, 14, -1, -1, -1}, - { 2, 4, 8, 10, 14, -1, -1, -1}, - { 0, 2, 4, 8, 10, 14, -1, -1}, - { 6, 8, 10, 14, -1, -1, -1, -1}, - { 0, 6, 8, 10, 14, -1, -1, -1}, - { 2, 6, 8, 10, 14, -1, -1, -1}, - { 0, 2, 6, 8, 10, 14, -1, -1}, - { 4, 6, 8, 10, 14, -1, -1, -1}, - { 0, 4, 6, 8, 10, 14, -1, -1}, - { 2, 4, 6, 8, 10, 14, -1, -1}, - { 0, 2, 4, 6, 8, 10, 14, -1}, - {12, 14, -1, -1, -1, -1, -1, -1}, - { 0, 12, 14, -1, -1, -1, -1, -1}, - { 2, 12, 14, -1, -1, -1, -1, -1}, - { 0, 2, 12, 14, -1, -1, -1, -1}, - { 4, 12, 14, -1, -1, -1, -1, -1}, - { 0, 4, 12, 14, -1, -1, -1, -1}, - { 2, 4, 12, 14, -1, -1, -1, -1}, - { 0, 2, 4, 12, 14, -1, -1, -1}, - { 6, 12, 14, -1, -1, -1, -1, -1}, - { 0, 6, 12, 14, -1, -1, -1, -1}, - { 2, 6, 12, 14, -1, -1, -1, -1}, - { 0, 2, 6, 12, 14, -1, -1, -1}, - { 4, 6, 12, 14, -1, -1, -1, -1}, - { 0, 4, 6, 12, 14, -1, -1, -1}, - { 2, 4, 6, 12, 14, -1, -1, -1}, - { 0, 2, 4, 6, 12, 14, -1, -1}, - { 8, 12, 14, -1, -1, -1, -1, -1}, - { 0, 8, 12, 14, -1, -1, -1, -1}, - { 2, 8, 12, 14, -1, -1, -1, -1}, - { 0, 2, 8, 12, 14, -1, -1, -1}, - { 4, 8, 12, 14, -1, -1, -1, -1}, - { 0, 4, 8, 12, 14, -1, -1, -1}, - { 2, 4, 8, 12, 14, -1, -1, -1}, - { 0, 2, 4, 8, 12, 14, -1, -1}, - { 6, 8, 12, 14, -1, -1, -1, -1}, - { 0, 6, 8, 12, 14, -1, -1, -1}, - { 2, 6, 8, 12, 14, -1, -1, -1}, - { 0, 2, 6, 8, 12, 14, -1, -1}, - { 4, 6, 8, 12, 14, -1, -1, -1}, - { 0, 4, 6, 8, 12, 14, -1, -1}, - { 2, 4, 6, 8, 12, 14, -1, -1}, - { 0, 2, 4, 6, 8, 12, 14, -1}, - {10, 12, 14, -1, -1, -1, -1, -1}, - { 0, 10, 12, 14, -1, -1, -1, -1}, - { 2, 10, 12, 14, -1, -1, -1, -1}, - { 0, 2, 10, 12, 14, -1, 
-1, -1}, - { 4, 10, 12, 14, -1, -1, -1, -1}, - { 0, 4, 10, 12, 14, -1, -1, -1}, - { 2, 4, 10, 12, 14, -1, -1, -1}, - { 0, 2, 4, 10, 12, 14, -1, -1}, - { 6, 10, 12, 14, -1, -1, -1, -1}, - { 0, 6, 10, 12, 14, -1, -1, -1}, - { 2, 6, 10, 12, 14, -1, -1, -1}, - { 0, 2, 6, 10, 12, 14, -1, -1}, - { 4, 6, 10, 12, 14, -1, -1, -1}, - { 0, 4, 6, 10, 12, 14, -1, -1}, - { 2, 4, 6, 10, 12, 14, -1, -1}, - { 0, 2, 4, 6, 10, 12, 14, -1}, - { 8, 10, 12, 14, -1, -1, -1, -1}, - { 0, 8, 10, 12, 14, -1, -1, -1}, - { 2, 8, 10, 12, 14, -1, -1, -1}, - { 0, 2, 8, 10, 12, 14, -1, -1}, - { 4, 8, 10, 12, 14, -1, -1, -1}, - { 0, 4, 8, 10, 12, 14, -1, -1}, - { 2, 4, 8, 10, 12, 14, -1, -1}, - { 0, 2, 4, 8, 10, 12, 14, -1}, - { 6, 8, 10, 12, 14, -1, -1, -1}, - { 0, 6, 8, 10, 12, 14, -1, -1}, - { 2, 6, 8, 10, 12, 14, -1, -1}, - { 0, 2, 6, 8, 10, 12, 14, -1}, - { 4, 6, 8, 10, 12, 14, -1, -1}, - { 0, 4, 6, 8, 10, 12, 14, -1}, - { 2, 4, 6, 8, 10, 12, 14, -1}, - { 0, 2, 4, 6, 8, 10, 12, 14} -}; -#endif - -#define _mm256_cmpge_epu16(a, b) _mm256_cmpeq_epi16(_mm256_max_epu16(a, b), a) -#define _mm_cmpge_epu16(a, b) _mm_cmpeq_epi16(_mm_max_epu16(a, b), a) - -unsigned int rej_uniform_avx(int16_t * restrict r, const uint8_t *buf) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - uint32_t good; -#ifdef BMI - uint64_t idx0, idx1, idx2, idx3; -#endif - const __m256i bound = _mm256_load_si256(&qdata.vec[_16XQ/16]); - const __m256i ones = _mm256_set1_epi8(1); - const __m256i mask = _mm256_set1_epi16(0xFFF); - const __m256i idx8 = _mm256_set_epi8(15,14,14,13,12,11,11,10, - 9, 8, 8, 7, 6, 5, 5, 4, - 11,10,10, 9, 8, 7, 7, 6, - 5, 4, 4, 3, 2, 1, 1, 0); - __m256i f0, f1, g0, g1, g2, g3; - __m128i f, t, pilo, pihi; - - ctr = pos = 0; - while(ctr <= KYBER_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 56) { - f0 = _mm256_loadu_si256((__m256i *)&buf[pos]); - f1 = _mm256_loadu_si256((__m256i *)&buf[pos+24]); - f0 = _mm256_permute4x64_epi64(f0, 0x94); - f1 = _mm256_permute4x64_epi64(f1, 0x94); - f0 = _mm256_shuffle_epi8(f0, 
idx8); - f1 = _mm256_shuffle_epi8(f1, idx8); - g0 = _mm256_srli_epi16(f0, 4); - g1 = _mm256_srli_epi16(f1, 4); - f0 = _mm256_blend_epi16(f0, g0, 0xAA); - f1 = _mm256_blend_epi16(f1, g1, 0xAA); - f0 = _mm256_and_si256(f0, mask); - f1 = _mm256_and_si256(f1, mask); - pos += 48; - - g0 = _mm256_cmpgt_epi16(bound, f0); - g1 = _mm256_cmpgt_epi16(bound, f1); - - g0 = _mm256_packs_epi16(g0, g1); - good = _mm256_movemask_epi8(g0); - -#ifdef BMI - idx0 = _pdep_u64(good >> 0, 0x0101010101010101); - idx1 = _pdep_u64(good >> 8, 0x0101010101010101); - idx2 = _pdep_u64(good >> 16, 0x0101010101010101); - idx3 = _pdep_u64(good >> 24, 0x0101010101010101); - idx0 = (idx0 << 8) - idx0; - idx0 = _pext_u64(0x0E0C0A0806040200, idx0); - idx1 = (idx1 << 8) - idx1; - idx1 = _pext_u64(0x0E0C0A0806040200, idx1); - idx2 = (idx2 << 8) - idx2; - idx2 = _pext_u64(0x0E0C0A0806040200, idx2); - idx3 = (idx3 << 8) - idx3; - idx3 = _pext_u64(0x0E0C0A0806040200, idx3); - - g0 = _mm256_castsi128_si256(_mm_cvtsi64_si128(idx0)); - g1 = _mm256_castsi128_si256(_mm_cvtsi64_si128(idx1)); - g0 = _mm256_inserti128_si256(g0, _mm_cvtsi64_si128(idx2), 1); - g1 = _mm256_inserti128_si256(g1, _mm_cvtsi64_si128(idx3), 1); -#else - g0 = _mm256_castsi128_si256(_mm_loadl_epi64((__m128i *)&idx[(good >> 0) & 0xFF])); - g1 = _mm256_castsi128_si256(_mm_loadl_epi64((__m128i *)&idx[(good >> 8) & 0xFF])); - g0 = _mm256_inserti128_si256(g0, _mm_loadl_epi64((__m128i *)&idx[(good >> 16) & 0xFF]), 1); - g1 = _mm256_inserti128_si256(g1, _mm_loadl_epi64((__m128i *)&idx[(good >> 24) & 0xFF]), 1); -#endif - - g2 = _mm256_add_epi8(g0, ones); - g3 = _mm256_add_epi8(g1, ones); - g0 = _mm256_unpacklo_epi8(g0, g2); - g1 = _mm256_unpacklo_epi8(g1, g3); - - f0 = _mm256_shuffle_epi8(f0, g0); - f1 = _mm256_shuffle_epi8(f1, g1); - - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f0)); - ctr += _mm_popcnt_u32((good >> 0) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f0, 1)); - ctr += _mm_popcnt_u32((good >> 
16) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_castsi256_si128(f1)); - ctr += _mm_popcnt_u32((good >> 8) & 0xFF); - _mm_storeu_si128((__m128i *)&r[ctr], _mm256_extracti128_si256(f1, 1)); - ctr += _mm_popcnt_u32((good >> 24) & 0xFF); - } - - while(ctr <= KYBER_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 16) { - f = _mm_loadu_si128((__m128i *)&buf[pos]); - f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8)); - t = _mm_srli_epi16(f, 4); - f = _mm_blend_epi16(f, t, 0xAA); - f = _mm_and_si128(f, _mm256_castsi256_si128(mask)); - pos += 12; - - t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f); - good = _mm_movemask_epi8(t); - -#ifdef BMI - good &= 0x5555; - idx0 = _pdep_u64(good, 0x1111111111111111); - idx0 = (idx0 << 8) - idx0; - idx0 = _pext_u64(0x0E0C0A0806040200, idx0); - pilo = _mm_cvtsi64_si128(idx0); -#else - good = _pext_u32(good, 0x5555); - pilo = _mm_loadl_epi64((__m128i *)&idx[good]); -#endif - - pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones)); - pilo = _mm_unpacklo_epi8(pilo, pihi); - f = _mm_shuffle_epi8(f, pilo); - _mm_storeu_si128((__m128i *)&r[ctr], f); - ctr += _mm_popcnt_u32(good); - } - - while(ctr < KYBER_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3) { - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)); - pos += 3; - - if(val0 < KYBER_Q) - r[ctr++] = val0; - if(val1 < KYBER_Q && ctr < KYBER_N) - r[ctr++] = val1; - } - - return ctr; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.h deleted file mode 100644 index 3be5e2192e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/rejsample.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef REJSAMPLE_H -#define REJSAMPLE_H - -#include -#include "params.h" -#include "symmetric.h" - -#define REJ_UNIFORM_AVX_NBLOCKS ((12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) -#define 
REJ_UNIFORM_AVX_BUFLEN (REJ_UNIFORM_AVX_NBLOCKS*XOF_BLOCKBYTES) - -#define rej_uniform_avx KYBER_NAMESPACE(rej_uniform_avx) -unsigned int rej_uniform_avx(int16_t *r, const uint8_t *buf); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric-shake.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric-shake.c deleted file mode 100644 index 20f451882e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric-shake.c +++ /dev/null @@ -1,74 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "symmetric.h" -#include "fips202.h" - -/************************************************* -* Name: kyber_shake128_absorb -* -* Description: Absorb step of the SHAKE128 specialized for the Kyber context. -* -* Arguments: - keccak_state *state: pointer to (uninitialized) output Keccak state -* - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state -* - uint8_t i: additional byte of input -* - uint8_t j: additional byte of input -**************************************************/ -void kyber_shake128_absorb(shake128incctx *state, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y) -{ - uint8_t extseed[KYBER_SYMBYTES+2]; - - memcpy(extseed, seed, KYBER_SYMBYTES); - extseed[KYBER_SYMBYTES+0] = x; - extseed[KYBER_SYMBYTES+1] = y; - - shake128_absorb_once(state, extseed, sizeof(extseed)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t 
key[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t extkey[KYBER_SYMBYTES+1]; - - memcpy(extkey, key, KYBER_SYMBYTES); - extkey[KYBER_SYMBYTES] = nonce; - - shake256(out, outlen, extkey, sizeof(extkey)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]) -{ - shake256incctx s; - - shake256_inc_init(&s); - shake256_inc_absorb(&s, key, KYBER_SYMBYTES); - shake256_inc_absorb(&s, input, KYBER_CIPHERTEXTBYTES); - shake256_inc_finalize(&s); - shake256_inc_squeeze(out, KYBER_SSBYTES, &s); - shake256_inc_ctx_release(&s); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric.h deleted file mode 100644 index e4941f7a86..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/symmetric.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef SYMMETRIC_H -#define SYMMETRIC_H - -#include -#include -#include "params.h" - -#include "fips202.h" -#include "fips202x4.h" - -typedef shake128incctx xof_state; - -#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb) -void kyber_shake128_absorb(shake128incctx *s, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y); - -#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf) -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce); - -#define kyber_shake256_rkprf 
KYBER_NAMESPACE(kyber_shake256_rkprf) -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]); - -#define XOF_BLOCKBYTES SHAKE128_RATE - -#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) -#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) -#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y) -#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) -#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE) -#define rkprf(OUT, KEY, INPUT) kyber_shake256_rkprf(OUT, KEY, INPUT) - -#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/verify.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/verify.c deleted file mode 100644 index 06243b837f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_avx2/verify.c +++ /dev/null @@ -1,83 +0,0 @@ -#include -#include -#include -#include "verify.h" - -/************************************************* -* Name: verify -* -* Description: Compare two arrays for equality in constant time. -* -* Arguments: const uint8_t *a: pointer to first byte array -* const uint8_t *b: pointer to second byte array -* size_t len: length of the byte arrays -* -* Returns 0 if the byte arrays are equal, 1 otherwise -**************************************************/ -int verify(const uint8_t *a, const uint8_t *b, size_t len) -{ - size_t i; - uint64_t r; - __m256i f, g, h; - - h = _mm256_setzero_si256(); - for(i=0;i> 63; - return r; -} - -/************************************************* -* Name: cmov -* -* Description: Copy len bytes from x to r if b is 1; -* don't modify x if b is 0. Requires b to be in {0,1}; -* assumes two's complement representation of negative integers. -* Runs in constant time. 
-* -* Arguments: uint8_t *r: pointer to output byte array -* const uint8_t *x: pointer to input byte array -* size_t len: Amount of bytes to be copied -* uint8_t b: Condition bit; has to be in {0,1} -**************************************************/ -void cmov(uint8_t * restrict r, const uint8_t *x, size_t len, uint8_t b) -{ - size_t i; - __m256i xvec, rvec, bvec; - -#if defined(__GNUC__) || defined(__clang__) - // Prevent the compiler from - // 1) inferring that b is 0/1-valued, and - // 2) handling the two cases with a branch. - // This is not necessary when verify.c and kem.c are separate translation - // units, but we expect that downstream consumers will copy this code and/or - // change how it is built. - __asm__("" : "+r"(b) : /* no inputs */); -#endif - - bvec = _mm256_set1_epi64x(-(uint64_t)b); - for(i=0;i -#include -#include "params.h" - -#define verify KYBER_NAMESPACE(verify) -int verify(const uint8_t *a, const uint8_t *b, size_t len); - -#define cmov KYBER_NAMESPACE(cmov) -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); - -#define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/api.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/api.h deleted file mode 100644 index 70d40f3f3e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/api.h +++ /dev/null @@ -1,66 +0,0 @@ -#ifndef API_H -#define API_H - -#include - -#define pqcrystals_kyber512_SECRETKEYBYTES 1632 -#define pqcrystals_kyber512_PUBLICKEYBYTES 800 -#define pqcrystals_kyber512_CIPHERTEXTBYTES 768 -#define pqcrystals_kyber512_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber512_ENCCOINBYTES 32 -#define pqcrystals_kyber512_BYTES 32 - -#define pqcrystals_kyber512_ref_SECRETKEYBYTES pqcrystals_kyber512_SECRETKEYBYTES -#define pqcrystals_kyber512_ref_PUBLICKEYBYTES pqcrystals_kyber512_PUBLICKEYBYTES -#define 
pqcrystals_kyber512_ref_CIPHERTEXTBYTES pqcrystals_kyber512_CIPHERTEXTBYTES -#define pqcrystals_kyber512_ref_KEYPAIRCOINBYTES pqcrystals_kyber512_KEYPAIRCOINBYTES -#define pqcrystals_kyber512_ref_ENCCOINBYTES pqcrystals_kyber512_ENCCOINBYTES -#define pqcrystals_kyber512_ref_BYTES pqcrystals_kyber512_BYTES - -int pqcrystals_kyber512_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber512_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber512_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber512_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber512_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber768_SECRETKEYBYTES 2400 -#define pqcrystals_kyber768_PUBLICKEYBYTES 1184 -#define pqcrystals_kyber768_CIPHERTEXTBYTES 1088 -#define pqcrystals_kyber768_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber768_ENCCOINBYTES 32 -#define pqcrystals_kyber768_BYTES 32 - -#define pqcrystals_kyber768_ref_SECRETKEYBYTES pqcrystals_kyber768_SECRETKEYBYTES -#define pqcrystals_kyber768_ref_PUBLICKEYBYTES pqcrystals_kyber768_PUBLICKEYBYTES -#define pqcrystals_kyber768_ref_CIPHERTEXTBYTES pqcrystals_kyber768_CIPHERTEXTBYTES -#define pqcrystals_kyber768_ref_KEYPAIRCOINBYTES pqcrystals_kyber768_KEYPAIRCOINBYTES -#define pqcrystals_kyber768_ref_ENCCOINBYTES pqcrystals_kyber768_ENCCOINBYTES -#define pqcrystals_kyber768_ref_BYTES pqcrystals_kyber768_BYTES - -int pqcrystals_kyber768_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber768_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber768_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber768_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber768_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#define pqcrystals_kyber1024_SECRETKEYBYTES 3168 -#define 
pqcrystals_kyber1024_PUBLICKEYBYTES 1568 -#define pqcrystals_kyber1024_CIPHERTEXTBYTES 1568 -#define pqcrystals_kyber1024_KEYPAIRCOINBYTES 64 -#define pqcrystals_kyber1024_ENCCOINBYTES 32 -#define pqcrystals_kyber1024_BYTES 32 - -#define pqcrystals_kyber1024_ref_SECRETKEYBYTES pqcrystals_kyber1024_SECRETKEYBYTES -#define pqcrystals_kyber1024_ref_PUBLICKEYBYTES pqcrystals_kyber1024_PUBLICKEYBYTES -#define pqcrystals_kyber1024_ref_CIPHERTEXTBYTES pqcrystals_kyber1024_CIPHERTEXTBYTES -#define pqcrystals_kyber1024_ref_KEYPAIRCOINBYTES pqcrystals_kyber1024_KEYPAIRCOINBYTES -#define pqcrystals_kyber1024_ref_ENCCOINBYTES pqcrystals_kyber1024_ENCCOINBYTES -#define pqcrystals_kyber1024_ref_BYTES pqcrystals_kyber1024_BYTES - -int pqcrystals_kyber1024_ref_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); -int pqcrystals_kyber1024_ref_keypair(uint8_t *pk, uint8_t *sk); -int pqcrystals_kyber1024_ref_enc_derand(uint8_t *ct, uint8_t *ss, const uint8_t *pk, const uint8_t *coins); -int pqcrystals_kyber1024_ref_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); -int pqcrystals_kyber1024_ref_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.c deleted file mode 100644 index 1500ffea56..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.c +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include "params.h" -#include "cbd.h" - -/************************************************* -* Name: load32_littleendian -* -* Description: load 4 bytes into a 32-bit integer -* in little-endian order -* -* Arguments: - const uint8_t *x: pointer to input byte array -* -* Returns 32-bit unsigned integer loaded from x -**************************************************/ -static uint32_t load32_littleendian(const uint8_t x[4]) -{ - uint32_t r; - r = (uint32_t)x[0]; - r |= (uint32_t)x[1] << 8; - r |= (uint32_t)x[2] << 16; - 
r |= (uint32_t)x[3] << 24; - return r; -} - -/************************************************* -* Name: load24_littleendian -* -* Description: load 3 bytes into a 32-bit integer -* in little-endian order. -* This function is only needed for Kyber-512 -* -* Arguments: - const uint8_t *x: pointer to input byte array -* -* Returns 32-bit unsigned integer loaded from x (most significant byte is zero) -**************************************************/ -#if KYBER_ETA1 == 3 -static uint32_t load24_littleendian(const uint8_t x[3]) -{ - uint32_t r; - r = (uint32_t)x[0]; - r |= (uint32_t)x[1] << 8; - r |= (uint32_t)x[2] << 16; - return r; -} -#endif - - -/************************************************* -* Name: cbd2 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *buf: pointer to input byte array -**************************************************/ -static void cbd2(poly *r, const uint8_t buf[2*KYBER_N/4]) -{ - unsigned int i,j; - uint32_t t,d; - int16_t a,b; - - for(i=0;i>1) & 0x55555555; - - for(j=0;j<8;j++) { - a = (d >> (4*j+0)) & 0x3; - b = (d >> (4*j+2)) & 0x3; - r->coeffs[8*i+j] = a - b; - } - } -} - -/************************************************* -* Name: cbd3 -* -* Description: Given an array of uniformly random bytes, compute -* polynomial with coefficients distributed according to -* a centered binomial distribution with parameter eta=3. 
-* This function is only needed for Kyber-512 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *buf: pointer to input byte array -**************************************************/ -#if KYBER_ETA1 == 3 -static void cbd3(poly *r, const uint8_t buf[3*KYBER_N/4]) -{ - unsigned int i,j; - uint32_t t,d; - int16_t a,b; - - for(i=0;i>1) & 0x00249249; - d += (t>>2) & 0x00249249; - - for(j=0;j<4;j++) { - a = (d >> (6*j+0)) & 0x7; - b = (d >> (6*j+3)) & 0x7; - r->coeffs[4*i+j] = a - b; - } - } -} -#endif - -void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1*KYBER_N/4]) -{ -#if KYBER_ETA1 == 2 - cbd2(r, buf); -#elif KYBER_ETA1 == 3 - cbd3(r, buf); -#else -#error "This implementation requires eta1 in {2,3}" -#endif -} - -void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2*KYBER_N/4]) -{ -#if KYBER_ETA2 == 2 - cbd2(r, buf); -#else -#error "This implementation requires eta2 = 2" -#endif -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.h deleted file mode 100644 index 7b677d745d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/cbd.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CBD_H -#define CBD_H - -#include -#include "params.h" -#include "poly.h" - -#define poly_cbd_eta1 KYBER_NAMESPACE(poly_cbd_eta1) -void poly_cbd_eta1(poly *r, const uint8_t buf[KYBER_ETA1*KYBER_N/4]); - -#define poly_cbd_eta2 KYBER_NAMESPACE(poly_cbd_eta2) -void poly_cbd_eta2(poly *r, const uint8_t buf[KYBER_ETA2*KYBER_N/4]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/indcpa.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/indcpa.c deleted file mode 100644 index 726cfa985d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/indcpa.c +++ /dev/null @@ -1,334 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "indcpa.h" -#include "polyvec.h" -#include "poly.h" -#include "ntt.h" 
-#include "symmetric.h" -#include "randombytes.h" - -/************************************************* -* Name: pack_pk -* -* Description: Serialize the public key as concatenation of the -* serialized vector of polynomials pk -* and the public seed used to generate the matrix A. -* -* Arguments: uint8_t *r: pointer to the output serialized public key -* polyvec *pk: pointer to the input public-key polyvec -* const uint8_t *seed: pointer to the input public seed -**************************************************/ -static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], - polyvec *pk, - const uint8_t seed[KYBER_SYMBYTES]) -{ - polyvec_tobytes(r, pk); - memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES); -} - -/************************************************* -* Name: unpack_pk -* -* Description: De-serialize public key from a byte array; -* approximate inverse of pack_pk -* -* Arguments: - polyvec *pk: pointer to output public-key polynomial vector -* - uint8_t *seed: pointer to output seed to generate matrix A -* - const uint8_t *packedpk: pointer to input serialized public key -**************************************************/ -static void unpack_pk(polyvec *pk, - uint8_t seed[KYBER_SYMBYTES], - const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) -{ - polyvec_frombytes(pk, packedpk); - memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES); -} - -/************************************************* -* Name: pack_sk -* -* Description: Serialize the secret key -* -* Arguments: - uint8_t *r: pointer to output serialized secret key -* - polyvec *sk: pointer to input vector of polynomials (secret key) -**************************************************/ -static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) -{ - polyvec_tobytes(r, sk); -} - -/************************************************* -* Name: unpack_sk -* -* Description: De-serialize the secret key; inverse of pack_sk -* -* Arguments: - polyvec *sk: pointer to output vector of 
polynomials (secret key) -* - const uint8_t *packedsk: pointer to input serialized secret key -**************************************************/ -static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) -{ - polyvec_frombytes(sk, packedsk); -} - -/************************************************* -* Name: pack_ciphertext -* -* Description: Serialize the ciphertext as concatenation of the -* compressed and serialized vector of polynomials b -* and the compressed and serialized polynomial v -* -* Arguments: uint8_t *r: pointer to the output serialized ciphertext -* poly *pk: pointer to the input vector of polynomials b -* poly *v: pointer to the input polynomial v -**************************************************/ -static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v) -{ - polyvec_compress(r, b); - poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v); -} - -/************************************************* -* Name: unpack_ciphertext -* -* Description: De-serialize and decompress ciphertext from a byte array; -* approximate inverse of pack_ciphertext -* -* Arguments: - polyvec *b: pointer to the output vector of polynomials b -* - poly *v: pointer to the output polynomial v -* - const uint8_t *c: pointer to the input serialized ciphertext -**************************************************/ -static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES]) -{ - polyvec_decompress(b, c); - poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES); -} - -/************************************************* -* Name: rej_uniform -* -* Description: Run rejection sampling on uniform random bytes to generate -* uniform random integers mod q -* -* Arguments: - int16_t *r: pointer to output buffer -* - unsigned int len: requested number of 16-bit integers (uniform mod q) -* - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes) -* - unsigned int buflen: length of input buffer 
in bytes -* -* Returns number of sampled 16-bit integers (at most len) -**************************************************/ -static unsigned int rej_uniform(int16_t *r, - unsigned int len, - const uint8_t *buf, - unsigned int buflen) -{ - unsigned int ctr, pos; - uint16_t val0, val1; - - ctr = pos = 0; - while(ctr < len && pos + 3 <= buflen) { - val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF; - val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF; - pos += 3; - - if(val0 < KYBER_Q) - r[ctr++] = val0; - if(ctr < len && val1 < KYBER_Q) - r[ctr++] = val1; - } - - return ctr; -} - -#define gen_a(A,B) gen_matrix(A,B,0) -#define gen_at(A,B) gen_matrix(A,B,1) - -/************************************************* -* Name: gen_matrix -* -* Description: Deterministically generate matrix A (or the transpose of A) -* from a seed. Entries of the matrix are polynomials that look -* uniformly random. Performs rejection sampling on output of -* a XOF -* -* Arguments: - polyvec *a: pointer to ouptput matrix A -* - const uint8_t *seed: pointer to input seed -* - int transposed: boolean deciding whether A or A^T is generated -**************************************************/ -#if(XOF_BLOCKBYTES % 3) -#error "Implementation of gen_matrix assumes that XOF_BLOCKBYTES is a multiple of 3" -#endif - -#define GEN_MATRIX_NBLOCKS ((12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) -// Not static for benchmarking -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed) -{ - unsigned int ctr, i, j; - unsigned int buflen; - uint8_t buf[GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]; - xof_state state; - xof_init(&state, seed); - - for(i=0;i -#include "params.h" -#include "polyvec.h" - -#define gen_matrix KYBER_NAMESPACE(gen_matrix) -void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES], int transposed); - -#define indcpa_keypair_derand KYBER_NAMESPACE(indcpa_keypair_derand) -void indcpa_keypair_derand(uint8_t 
pk[KYBER_INDCPA_PUBLICKEYBYTES], - uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_enc KYBER_NAMESPACE(indcpa_enc) -void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES], - const uint8_t coins[KYBER_SYMBYTES]); - -#define indcpa_dec KYBER_NAMESPACE(indcpa_dec) -void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES], - const uint8_t c[KYBER_INDCPA_BYTES], - const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES]); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.c deleted file mode 100644 index 63abc1029c..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.c +++ /dev/null @@ -1,169 +0,0 @@ -#include -#include -#include -#include "params.h" -#include "kem.h" -#include "indcpa.h" -#include "verify.h" -#include "symmetric.h" -#include "randombytes.h" -/************************************************* -* Name: crypto_kem_keypair_derand -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* - uint8_t *coins: pointer to input randomness -* (an already allocated array filled with 2*KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair_derand(uint8_t *pk, - uint8_t *sk, - const uint8_t *coins) -{ - indcpa_keypair_derand(pk, sk, coins); - memcpy(sk+KYBER_INDCPA_SECRETKEYBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_h(sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - /* Value z for pseudo-random output on reject */ - memcpy(sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES, 
coins+KYBER_SYMBYTES, KYBER_SYMBYTES); - return 0; -} - -/************************************************* -* Name: crypto_kem_keypair -* -* Description: Generates public and private key -* for CCA-secure Kyber key encapsulation mechanism -* -* Arguments: - uint8_t *pk: pointer to output public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - uint8_t *sk: pointer to output private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_keypair(uint8_t *pk, - uint8_t *sk) -{ - uint8_t coins[2*KYBER_SYMBYTES]; - randombytes(coins, 2*KYBER_SYMBYTES); - crypto_kem_keypair_derand(pk, sk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_enc_derand -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* - const uint8_t *coins: pointer to input randomness -* (an already allocated array filled with KYBER_SYMBYTES random bytes) -** -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc_derand(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk, - const uint8_t *coins) -{ - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - - memcpy(buf, coins, KYBER_SYMBYTES); - - /* Multitarget countermeasure for coins + contributory KEM */ - hash_h(buf+KYBER_SYMBYTES, pk, KYBER_PUBLICKEYBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(ct, buf, pk, kr+KYBER_SYMBYTES); - - memcpy(ss,kr,KYBER_SYMBYTES); - return 0; -} - 
-/************************************************* -* Name: crypto_kem_enc -* -* Description: Generates cipher text and shared -* secret for given public key -* -* Arguments: - uint8_t *ct: pointer to output cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *pk: pointer to input public key -* (an already allocated array of KYBER_PUBLICKEYBYTES bytes) -* -* Returns 0 (success) -**************************************************/ -int crypto_kem_enc(uint8_t *ct, - uint8_t *ss, - const uint8_t *pk) -{ - uint8_t coins[KYBER_SYMBYTES]; - randombytes(coins, KYBER_SYMBYTES); - crypto_kem_enc_derand(ct, ss, pk, coins); - return 0; -} - -/************************************************* -* Name: crypto_kem_dec -* -* Description: Generates shared secret for given -* cipher text and private key -* -* Arguments: - uint8_t *ss: pointer to output shared secret -* (an already allocated array of KYBER_SSBYTES bytes) -* - const uint8_t *ct: pointer to input cipher text -* (an already allocated array of KYBER_CIPHERTEXTBYTES bytes) -* - const uint8_t *sk: pointer to input private key -* (an already allocated array of KYBER_SECRETKEYBYTES bytes) -* -* Returns 0. -* -* On failure, ss will contain a pseudo-random value. 
-**************************************************/ -int crypto_kem_dec(uint8_t *ss, - const uint8_t *ct, - const uint8_t *sk) -{ - int fail; - uint8_t buf[2*KYBER_SYMBYTES]; - /* Will contain key, coins */ - uint8_t kr[2*KYBER_SYMBYTES]; - uint8_t cmp[KYBER_CIPHERTEXTBYTES+KYBER_SYMBYTES]; - const uint8_t *pk = sk+KYBER_INDCPA_SECRETKEYBYTES; - - indcpa_dec(buf, ct, sk); - - /* Multitarget countermeasure for coins + contributory KEM */ - memcpy(buf+KYBER_SYMBYTES, sk+KYBER_SECRETKEYBYTES-2*KYBER_SYMBYTES, KYBER_SYMBYTES); - hash_g(kr, buf, 2*KYBER_SYMBYTES); - - /* coins are in kr+KYBER_SYMBYTES */ - indcpa_enc(cmp, buf, pk, kr+KYBER_SYMBYTES); - - fail = verify(ct, cmp, KYBER_CIPHERTEXTBYTES); - - /* Compute rejection key */ - rkprf(ss,sk+KYBER_SECRETKEYBYTES-KYBER_SYMBYTES,ct); - - /* Copy true key to return buffer if fail is false */ - cmov(ss,kr,KYBER_SYMBYTES,!fail); - - return 0; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.h deleted file mode 100644 index 234f11966b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/kem.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef KEM_H -#define KEM_H - -#include -#include "params.h" - -#define CRYPTO_SECRETKEYBYTES KYBER_SECRETKEYBYTES -#define CRYPTO_PUBLICKEYBYTES KYBER_PUBLICKEYBYTES -#define CRYPTO_CIPHERTEXTBYTES KYBER_CIPHERTEXTBYTES -#define CRYPTO_BYTES KYBER_SSBYTES - -#if (KYBER_K == 2) -#define CRYPTO_ALGNAME "Kyber512" -#elif (KYBER_K == 3) -#define CRYPTO_ALGNAME "Kyber768" -#elif (KYBER_K == 4) -#define CRYPTO_ALGNAME "Kyber1024" -#endif - -#define crypto_kem_keypair_derand KYBER_NAMESPACE(keypair_derand) -int crypto_kem_keypair_derand(uint8_t *pk, uint8_t *sk, const uint8_t *coins); - -#define crypto_kem_keypair KYBER_NAMESPACE(keypair) -int crypto_kem_keypair(uint8_t *pk, uint8_t *sk); - -#define crypto_kem_enc_derand KYBER_NAMESPACE(enc_derand) -int crypto_kem_enc_derand(uint8_t *ct, uint8_t 
*ss, const uint8_t *pk, const uint8_t *coins); - -#define crypto_kem_enc KYBER_NAMESPACE(enc) -int crypto_kem_enc(uint8_t *ct, uint8_t *ss, const uint8_t *pk); - -#define crypto_kem_dec KYBER_NAMESPACE(dec) -int crypto_kem_dec(uint8_t *ss, const uint8_t *ct, const uint8_t *sk); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.c deleted file mode 100644 index 2f2eb10b2f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.c +++ /dev/null @@ -1,146 +0,0 @@ -#include -#include "params.h" -#include "ntt.h" -#include "reduce.h" - -/* Code to generate zetas and zetas_inv used in the number-theoretic transform: - -#define KYBER_ROOT_OF_UNITY 17 - -static const uint8_t tree[128] = { - 0, 64, 32, 96, 16, 80, 48, 112, 8, 72, 40, 104, 24, 88, 56, 120, - 4, 68, 36, 100, 20, 84, 52, 116, 12, 76, 44, 108, 28, 92, 60, 124, - 2, 66, 34, 98, 18, 82, 50, 114, 10, 74, 42, 106, 26, 90, 58, 122, - 6, 70, 38, 102, 22, 86, 54, 118, 14, 78, 46, 110, 30, 94, 62, 126, - 1, 65, 33, 97, 17, 81, 49, 113, 9, 73, 41, 105, 25, 89, 57, 121, - 5, 69, 37, 101, 21, 85, 53, 117, 13, 77, 45, 109, 29, 93, 61, 125, - 3, 67, 35, 99, 19, 83, 51, 115, 11, 75, 43, 107, 27, 91, 59, 123, - 7, 71, 39, 103, 23, 87, 55, 119, 15, 79, 47, 111, 31, 95, 63, 127 -}; - -void init_ntt() { - unsigned int i; - int16_t tmp[128]; - - tmp[0] = MONT; - for(i=1;i<128;i++) - tmp[i] = fqmul(tmp[i-1],MONT*KYBER_ROOT_OF_UNITY % KYBER_Q); - - for(i=0;i<128;i++) { - zetas[i] = tmp[tree[i]]; - if(zetas[i] > KYBER_Q/2) - zetas[i] -= KYBER_Q; - if(zetas[i] < -KYBER_Q/2) - zetas[i] += KYBER_Q; - } -} -*/ - -const int16_t zetas[128] = { - -1044, -758, -359, -1517, 1493, 1422, 287, 202, - -171, 622, 1577, 182, 962, -1202, -1474, 1468, - 573, -1325, 264, 383, -829, 1458, -1602, -130, - -681, 1017, 732, 608, -1542, 411, -205, -1571, - 1223, 652, -552, 1015, -1293, 1491, -282, -1544, - 516, -8, -320, -666, -1618, -1162, 
126, 1469, - -853, -90, -271, 830, 107, -1421, -247, -951, - -398, 961, -1508, -725, 448, -1065, 677, -1275, - -1103, 430, 555, 843, -1251, 871, 1550, 105, - 422, 587, 177, -235, -291, -460, 1574, 1653, - -246, 778, 1159, -147, -777, 1483, -602, 1119, - -1590, 644, -872, 349, 418, 329, -156, -75, - 817, 1097, 603, 610, 1322, -1285, -1465, 384, - -1215, -136, 1218, -1335, -874, 220, -1187, -1659, - -1185, -1530, -1278, 794, -1510, -854, -870, 478, - -108, -308, 996, 991, 958, -1460, 1522, 1628 -}; - -/************************************************* -* Name: fqmul -* -* Description: Multiplication followed by Montgomery reduction -* -* Arguments: - int16_t a: first factor -* - int16_t b: second factor -* -* Returns 16-bit integer congruent to a*b*R^{-1} mod q -**************************************************/ -static int16_t fqmul(int16_t a, int16_t b) { - return montgomery_reduce((int32_t)a*b); -} - -/************************************************* -* Name: ntt -* -* Description: Inplace number-theoretic transform (NTT) in Rq. -* input is in standard order, output is in bitreversed order -* -* Arguments: - int16_t r[256]: pointer to input/output vector of elements of Zq -**************************************************/ -void ntt(int16_t r[256]) { - unsigned int len, start, j, k; - int16_t t, zeta; - - k = 1; - for(len = 128; len >= 2; len >>= 1) { - for(start = 0; start < 256; start = j + len) { - zeta = zetas[k++]; - for(j = start; j < start + len; j++) { - t = fqmul(zeta, r[j + len]); - r[j + len] = r[j] - t; - r[j] = r[j] + t; - } - } - } -} - -/************************************************* -* Name: invntt_tomont -* -* Description: Inplace inverse number-theoretic transform in Rq and -* multiplication by Montgomery factor 2^16. 
-* Input is in bitreversed order, output is in standard order -* -* Arguments: - int16_t r[256]: pointer to input/output vector of elements of Zq -**************************************************/ -void invntt(int16_t r[256]) { - unsigned int start, len, j, k; - int16_t t, zeta; - const int16_t f = 1441; // mont^2/128 - - k = 127; - for(len = 2; len <= 128; len <<= 1) { - for(start = 0; start < 256; start = j + len) { - zeta = zetas[k--]; - for(j = start; j < start + len; j++) { - t = r[j]; - r[j] = barrett_reduce(t + r[j + len]); - r[j + len] = r[j + len] - t; - r[j + len] = fqmul(zeta, r[j + len]); - } - } - } - - for(j = 0; j < 256; j++) - r[j] = fqmul(r[j], f); -} - -/************************************************* -* Name: basemul -* -* Description: Multiplication of polynomials in Zq[X]/(X^2-zeta) -* used for multiplication of elements in Rq in NTT domain -* -* Arguments: - int16_t r[2]: pointer to the output polynomial -* - const int16_t a[2]: pointer to the first factor -* - const int16_t b[2]: pointer to the second factor -* - int16_t zeta: integer defining the reduction polynomial -**************************************************/ -void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta) -{ - r[0] = fqmul(a[1], b[1]); - r[0] = fqmul(r[0], zeta); - r[0] += fqmul(a[0], b[0]); - r[1] = fqmul(a[0], b[1]); - r[1] += fqmul(a[1], b[0]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.h deleted file mode 100644 index 227ea74f08..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/ntt.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef NTT_H -#define NTT_H - -#include -#include "params.h" - -#define zetas KYBER_NAMESPACE(zetas) -extern const int16_t zetas[128]; - -#define ntt KYBER_NAMESPACE(ntt) -void ntt(int16_t poly[256]); - -#define invntt KYBER_NAMESPACE(invntt) -void invntt(int16_t poly[256]); - -#define basemul 
KYBER_NAMESPACE(basemul) -void basemul(int16_t r[2], const int16_t a[2], const int16_t b[2], int16_t zeta); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/params.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/params.h deleted file mode 100644 index fb4190b311..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/params.h +++ /dev/null @@ -1,55 +0,0 @@ -#ifndef PARAMS_H -#define PARAMS_H - -#ifndef KYBER_K -#define KYBER_K 3 /* Change this for different security strengths */ -#endif - - -/* Don't change parameters below this line */ -#if (KYBER_K == 2) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_512_ref_##s -#elif (KYBER_K == 3) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_768_ref_##s -#elif (KYBER_K == 4) -#define KYBER_NAMESPACE(s) pqcrystals_ml_kem_1024_ref_##s -#else -#error "KYBER_K must be in {2,3,4}" -#endif - -#define KYBER_N 256 -#define KYBER_Q 3329 - -#define KYBER_SYMBYTES 32 /* size in bytes of hashes, and seeds */ -#define KYBER_SSBYTES 32 /* size in bytes of shared key */ - -#define KYBER_POLYBYTES 384 -#define KYBER_POLYVECBYTES (KYBER_K * KYBER_POLYBYTES) - -#if KYBER_K == 2 -#define KYBER_ETA1 3 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 3 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 128 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 320) -#elif KYBER_K == 4 -#define KYBER_ETA1 2 -#define KYBER_POLYCOMPRESSEDBYTES 160 -#define KYBER_POLYVECCOMPRESSEDBYTES (KYBER_K * 352) -#endif - -#define KYBER_ETA2 2 - -#define KYBER_INDCPA_MSGBYTES (KYBER_SYMBYTES) -#define KYBER_INDCPA_PUBLICKEYBYTES (KYBER_POLYVECBYTES + KYBER_SYMBYTES) -#define KYBER_INDCPA_SECRETKEYBYTES (KYBER_POLYVECBYTES) -#define KYBER_INDCPA_BYTES (KYBER_POLYVECCOMPRESSEDBYTES + KYBER_POLYCOMPRESSEDBYTES) - -#define KYBER_PUBLICKEYBYTES (KYBER_INDCPA_PUBLICKEYBYTES) -/* 32 bytes of additional space to save H(pk) */ -#define 
KYBER_SECRETKEYBYTES (KYBER_INDCPA_SECRETKEYBYTES + KYBER_INDCPA_PUBLICKEYBYTES + 2*KYBER_SYMBYTES) -#define KYBER_CIPHERTEXTBYTES (KYBER_INDCPA_BYTES) - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.c deleted file mode 100644 index cbd3abfb54..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.c +++ /dev/null @@ -1,360 +0,0 @@ -#include -#include "params.h" -#include "poly.h" -#include "ntt.h" -#include "reduce.h" -#include "cbd.h" -#include "symmetric.h" -#include "verify.h" - -/************************************************* -* Name: poly_compress -* -* Description: Compression and subsequent serialization of a polynomial -* -* Arguments: - uint8_t *r: pointer to output byte array -* (of length KYBER_POLYCOMPRESSEDBYTES) -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a) -{ - unsigned int i,j; - int16_t u; - uint32_t d0; - uint8_t t[8]; - -#if (KYBER_POLYCOMPRESSEDBYTES == 128) - - for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint16_t)u << 4) + KYBER_Q/2)/KYBER_Q) & 15; */ - d0 = u << 4; - d0 += 1665; - d0 *= 80635; - d0 >>= 28; - t[j] = d0 & 0xf; - } - - r[0] = t[0] | (t[1] << 4); - r[1] = t[2] | (t[3] << 4); - r[2] = t[4] | (t[5] << 4); - r[3] = t[6] | (t[7] << 4); - r += 4; - } -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) - for(i=0;icoeffs[8*i+j]; - u += (u >> 15) & KYBER_Q; -/* t[j] = ((((uint32_t)u << 5) + KYBER_Q/2)/KYBER_Q) & 31; */ - d0 = u << 5; - d0 += 1664; - d0 *= 40318; - d0 >>= 27; - t[j] = d0 & 0x1f; - } - - r[0] = (t[0] >> 0) | (t[1] << 5); - r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7); - r[2] = (t[3] >> 1) | (t[4] << 4); - r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6); - r[4] = (t[6] >> 2) | (t[7] << 3); - r += 5; - } -#else -#error "KYBER_POLYCOMPRESSEDBYTES needs to be in 
{128, 160}" -#endif -} - -/************************************************* -* Name: poly_decompress -* -* Description: De-serialization and subsequent decompression of a polynomial; -* approximate inverse of poly_compress -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYCOMPRESSEDBYTES bytes) -**************************************************/ -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]) -{ - unsigned int i; - -#if (KYBER_POLYCOMPRESSEDBYTES == 128) - for(i=0;icoeffs[2*i+0] = (((uint16_t)(a[0] & 15)*KYBER_Q) + 8) >> 4; - r->coeffs[2*i+1] = (((uint16_t)(a[0] >> 4)*KYBER_Q) + 8) >> 4; - a += 1; - } -#elif (KYBER_POLYCOMPRESSEDBYTES == 160) - unsigned int j; - uint8_t t[8]; - for(i=0;i> 0); - t[1] = (a[0] >> 5) | (a[1] << 3); - t[2] = (a[1] >> 2); - t[3] = (a[1] >> 7) | (a[2] << 1); - t[4] = (a[2] >> 4) | (a[3] << 4); - t[5] = (a[3] >> 1); - t[6] = (a[3] >> 6) | (a[4] << 2); - t[7] = (a[4] >> 3); - a += 5; - - for(j=0;j<8;j++) - r->coeffs[8*i+j] = ((uint32_t)(t[j] & 31)*KYBER_Q + 16) >> 5; - } -#else -#error "KYBER_POLYCOMPRESSEDBYTES needs to be in {128, 160}" -#endif -} - -/************************************************* -* Name: poly_tobytes -* -* Description: Serialization of a polynomial -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYBYTES bytes) -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a) -{ - unsigned int i; - uint16_t t0, t1; - - for(i=0;icoeffs[2*i]; - t0 += ((int16_t)t0 >> 15) & KYBER_Q; - t1 = a->coeffs[2*i+1]; - t1 += ((int16_t)t1 >> 15) & KYBER_Q; - r[3*i+0] = (t0 >> 0); - r[3*i+1] = (t0 >> 8) | (t1 << 4); - r[3*i+2] = (t1 >> 4); - } -} - -/************************************************* -* Name: poly_frombytes -* -* Description: De-serialization of a polynomial; -* inverse of 
poly_tobytes -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *a: pointer to input byte array -* (of KYBER_POLYBYTES bytes) -**************************************************/ -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) -{ - unsigned int i; - for(i=0;icoeffs[2*i] = ((a[3*i+0] >> 0) | ((uint16_t)a[3*i+1] << 8)) & 0xFFF; - r->coeffs[2*i+1] = ((a[3*i+1] >> 4) | ((uint16_t)a[3*i+2] << 4)) & 0xFFF; - } -} - -/************************************************* -* Name: poly_frommsg -* -* Description: Convert 32-byte message to polynomial -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *msg: pointer to input message -**************************************************/ -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) -{ - unsigned int i,j; - -#if (KYBER_INDCPA_MSGBYTES != KYBER_N/8) -#error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!" -#endif - - for(i=0;icoeffs[8*i+j] = 0; - cmov_int16(r->coeffs+8*i+j, ((KYBER_Q+1)/2), (msg[i] >> j)&1); - } - } -} - -/************************************************* -* Name: poly_tomsg -* -* Description: Convert polynomial to 32-byte message -* -* Arguments: - uint8_t *msg: pointer to output message -* - const poly *a: pointer to input polynomial -**************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a) -{ - unsigned int i,j; - uint32_t t; - - for(i=0;icoeffs[8*i+j]; - // t += ((int16_t)t >> 15) & KYBER_Q; - // t = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1; - t <<= 1; - t += 1665; - t *= 80635; - t >>= 28; - t &= 1; - msg[i] |= t << j; - } - } -} - -/************************************************* -* Name: poly_getnoise_eta1 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA1 -* -* Arguments: - poly *r: pointer to output polynomial -* - const 
uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t buf[KYBER_ETA1*KYBER_N/4]; - prf(buf, sizeof(buf), seed, nonce); - poly_cbd_eta1(r, buf); -} - -/************************************************* -* Name: poly_getnoise_eta2 -* -* Description: Sample a polynomial deterministically from a seed and a nonce, -* with output polynomial close to centered binomial distribution -* with parameter KYBER_ETA2 -* -* Arguments: - poly *r: pointer to output polynomial -* - const uint8_t *seed: pointer to input seed -* (of length KYBER_SYMBYTES bytes) -* - uint8_t nonce: one-byte input nonce -**************************************************/ -void poly_getnoise_eta2(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t buf[KYBER_ETA2*KYBER_N/4]; - prf(buf, sizeof(buf), seed, nonce); - poly_cbd_eta2(r, buf); -} - - -/************************************************* -* Name: poly_ntt -* -* Description: Computes negacyclic number-theoretic transform (NTT) of -* a polynomial in place; -* inputs assumed to be in normal order, output in bitreversed order -* -* Arguments: - uint16_t *r: pointer to in/output polynomial -**************************************************/ -void poly_ntt(poly *r) -{ - ntt(r->coeffs); - poly_reduce(r); -} - -/************************************************* -* Name: poly_invntt_tomont -* -* Description: Computes inverse of negacyclic number-theoretic transform (NTT) -* of a polynomial in place; -* inputs assumed to be in bitreversed order, output in normal order -* -* Arguments: - uint16_t *a: pointer to in/output polynomial -**************************************************/ -void poly_invntt_tomont(poly *r) -{ - invntt(r->coeffs); -} - -/************************************************* -* Name: poly_basemul_montgomery 
-* -* Description: Multiplication of two polynomials in NTT domain -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[4*i], &a->coeffs[4*i], &b->coeffs[4*i], zetas[64+i]); - basemul(&r->coeffs[4*i+2], &a->coeffs[4*i+2], &b->coeffs[4*i+2], -zetas[64+i]); - } -} - -/************************************************* -* Name: poly_tomont -* -* Description: Inplace conversion of all coefficients of a polynomial -* from normal domain to Montgomery domain -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_tomont(poly *r) -{ - unsigned int i; - const int16_t f = (1ULL << 32) % KYBER_Q; - for(i=0;icoeffs[i] = montgomery_reduce((int32_t)r->coeffs[i]*f); -} - -/************************************************* -* Name: poly_reduce -* -* Description: Applies Barrett reduction to all coefficients of a polynomial -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - poly *r: pointer to input/output polynomial -**************************************************/ -void poly_reduce(poly *r) -{ - unsigned int i; - for(i=0;icoeffs[i] = barrett_reduce(r->coeffs[i]); -} - -/************************************************* -* Name: poly_add -* -* Description: Add two polynomials; no modular reduction is performed -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_add(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[i] = a->coeffs[i] + b->coeffs[i]; -} - -/************************************************* -* 
Name: poly_sub -* -* Description: Subtract two polynomials; no modular reduction is performed -* -* Arguments: - poly *r: pointer to output polynomial -* - const poly *a: pointer to first input polynomial -* - const poly *b: pointer to second input polynomial -**************************************************/ -void poly_sub(poly *r, const poly *a, const poly *b) -{ - unsigned int i; - for(i=0;icoeffs[i] = a->coeffs[i] - b->coeffs[i]; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.h deleted file mode 100644 index 9a99c7cdad..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/poly.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef POLY_H -#define POLY_H - -#include -#include "params.h" - -/* - * Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial - * coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] - */ -typedef struct{ - int16_t coeffs[KYBER_N]; -} poly; - -#define poly_compress KYBER_NAMESPACE(poly_compress) -void poly_compress(uint8_t r[KYBER_POLYCOMPRESSEDBYTES], const poly *a); -#define poly_decompress KYBER_NAMESPACE(poly_decompress) -void poly_decompress(poly *r, const uint8_t a[KYBER_POLYCOMPRESSEDBYTES]); - -#define poly_tobytes KYBER_NAMESPACE(poly_tobytes) -void poly_tobytes(uint8_t r[KYBER_POLYBYTES], const poly *a); -#define poly_frombytes KYBER_NAMESPACE(poly_frombytes) -void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]); - -#define poly_frommsg KYBER_NAMESPACE(poly_frommsg) -void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]); -#define poly_tomsg KYBER_NAMESPACE(poly_tomsg) -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r); - -#define poly_getnoise_eta1 KYBER_NAMESPACE(poly_getnoise_eta1) -void poly_getnoise_eta1(poly *r, const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_getnoise_eta2 KYBER_NAMESPACE(poly_getnoise_eta2) -void poly_getnoise_eta2(poly *r, 
const uint8_t seed[KYBER_SYMBYTES], uint8_t nonce); - -#define poly_ntt KYBER_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); -#define poly_invntt_tomont KYBER_NAMESPACE(poly_invntt_tomont) -void poly_invntt_tomont(poly *r); -#define poly_basemul_montgomery KYBER_NAMESPACE(poly_basemul_montgomery) -void poly_basemul_montgomery(poly *r, const poly *a, const poly *b); -#define poly_tomont KYBER_NAMESPACE(poly_tomont) -void poly_tomont(poly *r); - -#define poly_reduce KYBER_NAMESPACE(poly_reduce) -void poly_reduce(poly *r); - -#define poly_add KYBER_NAMESPACE(poly_add) -void poly_add(poly *r, const poly *a, const poly *b); -#define poly_sub KYBER_NAMESPACE(poly_sub) -void poly_sub(poly *r, const poly *a, const poly *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.c deleted file mode 100644 index 669f6a5f1d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.c +++ /dev/null @@ -1,246 +0,0 @@ -#include -#include "params.h" -#include "poly.h" -#include "polyvec.h" - -/************************************************* -* Name: polyvec_compress -* -* Description: Compress and serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECCOMPRESSEDBYTES) -* - const polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a) -{ - unsigned int i,j,k; - uint64_t d0; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - uint16_t t[8]; - for(i=0;ivec[i].coeffs[8*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 11) + KYBER_Q/2)/KYBER_Q) & 0x7ff; */ - d0 = t[k]; - d0 <<= 11; - d0 += 1664; - d0 *= 645084; - d0 >>= 31; - t[k] = d0 & 0x7ff; - } - - r[ 0] = (t[0] >> 0); - r[ 1] = (t[0] >> 8) | (t[1] << 3); - r[ 2] = (t[1] >> 5) 
| (t[2] << 6); - r[ 3] = (t[2] >> 2); - r[ 4] = (t[2] >> 10) | (t[3] << 1); - r[ 5] = (t[3] >> 7) | (t[4] << 4); - r[ 6] = (t[4] >> 4) | (t[5] << 7); - r[ 7] = (t[5] >> 1); - r[ 8] = (t[5] >> 9) | (t[6] << 2); - r[ 9] = (t[6] >> 6) | (t[7] << 5); - r[10] = (t[7] >> 3); - r += 11; - } - } -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - uint16_t t[4]; - for(i=0;ivec[i].coeffs[4*j+k]; - t[k] += ((int16_t)t[k] >> 15) & KYBER_Q; -/* t[k] = ((((uint32_t)t[k] << 10) + KYBER_Q/2)/ KYBER_Q) & 0x3ff; */ - d0 = t[k]; - d0 <<= 10; - d0 += 1665; - d0 *= 1290167; - d0 >>= 32; - t[k] = d0 & 0x3ff; - } - - r[0] = (t[0] >> 0); - r[1] = (t[0] >> 8) | (t[1] << 2); - r[2] = (t[1] >> 6) | (t[2] << 4); - r[3] = (t[2] >> 4) | (t[3] << 6); - r[4] = (t[3] >> 2); - r += 5; - } - } -#else -#error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}" -#endif -} - -/************************************************* -* Name: polyvec_decompress -* -* Description: De-serialize and decompress vector of polynomials; -* approximate inverse of polyvec_compress -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const uint8_t *a: pointer to input byte array -* (of length KYBER_POLYVECCOMPRESSEDBYTES) -**************************************************/ -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]) -{ - unsigned int i,j,k; - -#if (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 352)) - uint16_t t[8]; - for(i=0;i> 0) | ((uint16_t)a[ 1] << 8); - t[1] = (a[1] >> 3) | ((uint16_t)a[ 2] << 5); - t[2] = (a[2] >> 6) | ((uint16_t)a[ 3] << 2) | ((uint16_t)a[4] << 10); - t[3] = (a[4] >> 1) | ((uint16_t)a[ 5] << 7); - t[4] = (a[5] >> 4) | ((uint16_t)a[ 6] << 4); - t[5] = (a[6] >> 7) | ((uint16_t)a[ 7] << 1) | ((uint16_t)a[8] << 9); - t[6] = (a[8] >> 2) | ((uint16_t)a[ 9] << 6); - t[7] = (a[9] >> 5) | ((uint16_t)a[10] << 3); - a += 11; - - for(k=0;k<8;k++) - r->vec[i].coeffs[8*j+k] = ((uint32_t)(t[k] & 0x7FF)*KYBER_Q + 1024) >> 
11; - } - } -#elif (KYBER_POLYVECCOMPRESSEDBYTES == (KYBER_K * 320)) - uint16_t t[4]; - for(i=0;i> 0) | ((uint16_t)a[1] << 8); - t[1] = (a[1] >> 2) | ((uint16_t)a[2] << 6); - t[2] = (a[2] >> 4) | ((uint16_t)a[3] << 4); - t[3] = (a[3] >> 6) | ((uint16_t)a[4] << 2); - a += 5; - - for(k=0;k<4;k++) - r->vec[i].coeffs[4*j+k] = ((uint32_t)(t[k] & 0x3FF)*KYBER_Q + 512) >> 10; - } - } -#else -#error "KYBER_POLYVECCOMPRESSEDBYTES needs to be in {320*KYBER_K, 352*KYBER_K}" -#endif -} - -/************************************************* -* Name: polyvec_tobytes -* -* Description: Serialize vector of polynomials -* -* Arguments: - uint8_t *r: pointer to output byte array -* (needs space for KYBER_POLYVECBYTES) -* - const polyvec *a: pointer to input vector of polynomials -**************************************************/ -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_frombytes -* -* Description: De-serialize vector of polynomials; -* inverse of polyvec_tobytes -* -* Arguments: - uint8_t *r: pointer to output byte array -* - const polyvec *a: pointer to input vector of polynomials -* (of length KYBER_POLYVECBYTES) -**************************************************/ -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]) -{ - unsigned int i; - for(i=0;ivec[i], a+i*KYBER_POLYBYTES); -} - -/************************************************* -* Name: polyvec_ntt -* -* Description: Apply forward NTT to all elements of a vector of polynomials -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_ntt(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_invntt_tomont -* -* Description: Apply inverse NTT to all elements of a vector of polynomials -* and multiply by 
Montgomery factor 2^16 -* -* Arguments: - polyvec *r: pointer to in/output vector of polynomials -**************************************************/ -void polyvec_invntt_tomont(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_basemul_acc_montgomery -* -* Description: Multiply elements of a and b in NTT domain, accumulate into r, -* and multiply by 2^-16. -* -* Arguments: - poly *r: pointer to output polynomial -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - poly t; - - poly_basemul_montgomery(r, &a->vec[0], &b->vec[0]); - for(i=1;ivec[i], &b->vec[i]); - poly_add(r, r, &t); - } - - poly_reduce(r); -} - -/************************************************* -* Name: polyvec_reduce -* -* Description: Applies Barrett reduction to each coefficient -* of each element of a vector of polynomials; -* for details of the Barrett reduction see comments in reduce.c -* -* Arguments: - polyvec *r: pointer to input/output polynomial -**************************************************/ -void polyvec_reduce(polyvec *r) -{ - unsigned int i; - for(i=0;ivec[i]); -} - -/************************************************* -* Name: polyvec_add -* -* Description: Add vectors of polynomials -* -* Arguments: - polyvec *r: pointer to output vector of polynomials -* - const polyvec *a: pointer to first input vector of polynomials -* - const polyvec *b: pointer to second input vector of polynomials -**************************************************/ -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b) -{ - unsigned int i; - for(i=0;ivec[i], &a->vec[i], &b->vec[i]); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.h 
b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.h deleted file mode 100644 index 57b605494e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/polyvec.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef POLYVEC_H -#define POLYVEC_H - -#include -#include "params.h" -#include "poly.h" - -typedef struct{ - poly vec[KYBER_K]; -} polyvec; - -#define polyvec_compress KYBER_NAMESPACE(polyvec_compress) -void polyvec_compress(uint8_t r[KYBER_POLYVECCOMPRESSEDBYTES], const polyvec *a); -#define polyvec_decompress KYBER_NAMESPACE(polyvec_decompress) -void polyvec_decompress(polyvec *r, const uint8_t a[KYBER_POLYVECCOMPRESSEDBYTES]); - -#define polyvec_tobytes KYBER_NAMESPACE(polyvec_tobytes) -void polyvec_tobytes(uint8_t r[KYBER_POLYVECBYTES], const polyvec *a); -#define polyvec_frombytes KYBER_NAMESPACE(polyvec_frombytes) -void polyvec_frombytes(polyvec *r, const uint8_t a[KYBER_POLYVECBYTES]); - -#define polyvec_ntt KYBER_NAMESPACE(polyvec_ntt) -void polyvec_ntt(polyvec *r); -#define polyvec_invntt_tomont KYBER_NAMESPACE(polyvec_invntt_tomont) -void polyvec_invntt_tomont(polyvec *r); - -#define polyvec_basemul_acc_montgomery KYBER_NAMESPACE(polyvec_basemul_acc_montgomery) -void polyvec_basemul_acc_montgomery(poly *r, const polyvec *a, const polyvec *b); - -#define polyvec_reduce KYBER_NAMESPACE(polyvec_reduce) -void polyvec_reduce(polyvec *r); - -#define polyvec_add KYBER_NAMESPACE(polyvec_add) -void polyvec_add(polyvec *r, const polyvec *a, const polyvec *b); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.c deleted file mode 100644 index 9d8e7edf83..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.c +++ /dev/null @@ -1,42 +0,0 @@ -#include -#include "params.h" -#include "reduce.h" - -/************************************************* -* Name: montgomery_reduce -* -* Description: Montgomery reduction; given a 
32-bit integer a, computes -* 16-bit integer congruent to a * R^-1 mod q, where R=2^16 -* -* Arguments: - int32_t a: input integer to be reduced; -* has to be in {-q2^15,...,q2^15-1} -* -* Returns: integer in {-q+1,...,q-1} congruent to a * R^-1 modulo q. -**************************************************/ -int16_t montgomery_reduce(int32_t a) -{ - int16_t t; - - t = (int16_t)a*QINV; - t = (a - (int32_t)t*KYBER_Q) >> 16; - return t; -} - -/************************************************* -* Name: barrett_reduce -* -* Description: Barrett reduction; given a 16-bit integer a, computes -* centered representative congruent to a mod q in {-(q-1)/2,...,(q-1)/2} -* -* Arguments: - int16_t a: input integer to be reduced -* -* Returns: integer in {-(q-1)/2,...,(q-1)/2} congruent to a modulo q. -**************************************************/ -int16_t barrett_reduce(int16_t a) { - int16_t t; - const int16_t v = ((1<<26) + KYBER_Q/2)/KYBER_Q; - - t = ((int32_t)v*a + (1<<25)) >> 26; - t *= KYBER_Q; - return a - t; -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.h deleted file mode 100644 index c1bc1e4c7b..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/reduce.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef REDUCE_H -#define REDUCE_H - -#include -#include "params.h" - -#define MONT -1044 // 2^16 mod q -#define QINV -3327 // q^-1 mod 2^16 - -#define montgomery_reduce KYBER_NAMESPACE(montgomery_reduce) -int16_t montgomery_reduce(int32_t a); - -#define barrett_reduce KYBER_NAMESPACE(barrett_reduce) -int16_t barrett_reduce(int16_t a); - -#endif diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric-shake.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric-shake.c deleted file mode 100644 index 20f451882e..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric-shake.c +++ /dev/null @@ -1,74 
+0,0 @@ -#include -#include -#include -#include "params.h" -#include "symmetric.h" -#include "fips202.h" - -/************************************************* -* Name: kyber_shake128_absorb -* -* Description: Absorb step of the SHAKE128 specialized for the Kyber context. -* -* Arguments: - keccak_state *state: pointer to (uninitialized) output Keccak state -* - const uint8_t *seed: pointer to KYBER_SYMBYTES input to be absorbed into state -* - uint8_t i: additional byte of input -* - uint8_t j: additional byte of input -**************************************************/ -void kyber_shake128_absorb(shake128incctx *state, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y) -{ - uint8_t extseed[KYBER_SYMBYTES+2]; - - memcpy(extseed, seed, KYBER_SYMBYTES); - extseed[KYBER_SYMBYTES+0] = x; - extseed[KYBER_SYMBYTES+1] = y; - - shake128_absorb_once(state, extseed, sizeof(extseed)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce) -{ - uint8_t extkey[KYBER_SYMBYTES+1]; - - memcpy(extkey, key, KYBER_SYMBYTES); - extkey[KYBER_SYMBYTES] = nonce; - - shake256(out, outlen, extkey, sizeof(extkey)); -} - -/************************************************* -* Name: kyber_shake256_prf -* -* Description: Usage of SHAKE256 as a PRF, concatenates secret and public input -* and then generates outlen bytes of SHAKE256 output -* -* Arguments: - uint8_t *out: pointer to output -* - size_t outlen: number of requested output 
bytes -* - const uint8_t *key: pointer to the key (of length KYBER_SYMBYTES) -* - uint8_t nonce: single-byte nonce (public PRF input) -**************************************************/ -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]) -{ - shake256incctx s; - - shake256_inc_init(&s); - shake256_inc_absorb(&s, key, KYBER_SYMBYTES); - shake256_inc_absorb(&s, input, KYBER_CIPHERTEXTBYTES); - shake256_inc_finalize(&s); - shake256_inc_squeeze(out, KYBER_SSBYTES, &s); - shake256_inc_ctx_release(&s); -} diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric.h b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric.h deleted file mode 100644 index 2acc66f98d..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/symmetric.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef SYMMETRIC_H -#define SYMMETRIC_H - -#include -#include -#include "params.h" - -#include "fips202.h" - -typedef shake128incctx xof_state; - -#define kyber_shake128_absorb KYBER_NAMESPACE(kyber_shake128_absorb) -void kyber_shake128_absorb(shake128incctx *s, - const uint8_t seed[KYBER_SYMBYTES], - uint8_t x, - uint8_t y); - -#define kyber_shake256_prf KYBER_NAMESPACE(kyber_shake256_prf) -void kyber_shake256_prf(uint8_t *out, size_t outlen, const uint8_t key[KYBER_SYMBYTES], uint8_t nonce); - -#define kyber_shake256_rkprf KYBER_NAMESPACE(kyber_shake256_rkprf) -void kyber_shake256_rkprf(uint8_t out[KYBER_SSBYTES], const uint8_t key[KYBER_SYMBYTES], const uint8_t input[KYBER_CIPHERTEXTBYTES]); - -#define XOF_BLOCKBYTES SHAKE128_RATE - -#define hash_h(OUT, IN, INBYTES) sha3_256(OUT, IN, INBYTES) -#define hash_g(OUT, IN, INBYTES) sha3_512(OUT, IN, INBYTES) -#define xof_init(STATE, SEED) shake128_inc_init(STATE) -#define xof_absorb(STATE, SEED, X, Y) kyber_shake128_absorb(STATE, SEED, X, Y) -#define xof_squeezeblocks(OUT, OUTBLOCKS, STATE) shake128_squeezeblocks(OUT, OUTBLOCKS, STATE) 
-#define xof_release(STATE) shake128_inc_ctx_release(STATE) -#define prf(OUT, OUTBYTES, KEY, NONCE) kyber_shake256_prf(OUT, OUTBYTES, KEY, NONCE) -#define rkprf(OUT, KEY, INPUT) kyber_shake256_rkprf(OUT, KEY, INPUT) - -#endif /* SYMMETRIC_H */ diff --git a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/verify.c b/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/verify.c deleted file mode 100644 index 914ccd448f..0000000000 --- a/src/kem/ml_kem/pqcrystals-kyber-standard_ml-kem-768_ref/verify.c +++ /dev/null @@ -1,75 +0,0 @@ -#include -#include -#include "verify.h" - -/************************************************* -* Name: verify -* -* Description: Compare two arrays for equality in constant time. -* -* Arguments: const uint8_t *a: pointer to first byte array -* const uint8_t *b: pointer to second byte array -* size_t len: length of the byte arrays -* -* Returns 0 if the byte arrays are equal, 1 otherwise -**************************************************/ -int verify(const uint8_t *a, const uint8_t *b, size_t len) -{ - size_t i; - uint8_t r = 0; - - for(i=0;i> 63; -} - -/************************************************* -* Name: cmov -* -* Description: Copy len bytes from x to r if b is 1; -* don't modify x if b is 0. Requires b to be in {0,1}; -* assumes two's complement representation of negative integers. -* Runs in constant time. -* -* Arguments: uint8_t *r: pointer to output byte array -* const uint8_t *x: pointer to input byte array -* size_t len: Amount of bytes to be copied -* uint8_t b: Condition bit; has to be in {0,1} -**************************************************/ -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b) -{ - size_t i; - -#if defined(__GNUC__) || defined(__clang__) - // Prevent the compiler from - // 1) inferring that b is 0/1-valued, and - // 2) handling the two cases with a branch. 
- // This is not necessary when verify.c and kem.c are separate translation - // units, but we expect that downstream consumers will copy this code and/or - // change how it is built. - __asm__("" : "+r"(b) : /* no inputs */); -#endif - - b = -b; - for(i=0;i -#include -#include "params.h" - -#define verify KYBER_NAMESPACE(verify) -int verify(const uint8_t *a, const uint8_t *b, size_t len); - -#define cmov KYBER_NAMESPACE(cmov) -void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); - -#define cmov_int16 KYBER_NAMESPACE(cmov_int16) -void cmov_int16(int16_t *r, int16_t v, uint16_t b); - -#endif diff --git a/src/oqsconfig.h.cmake b/src/oqsconfig.h.cmake index 967c35e64e..3c702b83bf 100644 --- a/src/oqsconfig.h.cmake +++ b/src/oqsconfig.h.cmake @@ -128,11 +128,14 @@ #cmakedefine OQS_ENABLE_KEM_ML_KEM 1 #cmakedefine OQS_ENABLE_KEM_ml_kem_512 1 -#cmakedefine OQS_ENABLE_KEM_ml_kem_512_avx2 1 +#cmakedefine OQS_ENABLE_KEM_ml_kem_512_x86_64 1 +#cmakedefine OQS_ENABLE_KEM_ml_kem_512_aarch64 1 #cmakedefine OQS_ENABLE_KEM_ml_kem_768 1 -#cmakedefine OQS_ENABLE_KEM_ml_kem_768_avx2 1 +#cmakedefine OQS_ENABLE_KEM_ml_kem_768_x86_64 1 +#cmakedefine OQS_ENABLE_KEM_ml_kem_768_aarch64 1 #cmakedefine OQS_ENABLE_KEM_ml_kem_1024 1 -#cmakedefine OQS_ENABLE_KEM_ml_kem_1024_avx2 1 +#cmakedefine OQS_ENABLE_KEM_ml_kem_1024_x86_64 1 +#cmakedefine OQS_ENABLE_KEM_ml_kem_1024_aarch64 1 #cmakedefine OQS_ENABLE_SIG_DILITHIUM 1 #cmakedefine OQS_ENABLE_SIG_dilithium_2 1 diff --git a/tests/test_binary.py b/tests/test_binary.py index 53e114df00..a673e545ae 100644 --- a/tests/test_binary.py +++ b/tests/test_binary.py @@ -33,7 +33,7 @@ def test_namespace(): symbols.append(line) # ideally this would be just ['oqs', 'pqclean'], but contains exceptions (e.g., providing compat implementations of unavailable platform functions) - namespaces = ['oqs', 'pqclean', 'keccak', 'pqcrystals', 'pqmayo', 'init', 'fini', 'seedexpander', '__x86.get_pc_thunk', 'libjade', 'jade', '__jade', '__jasmin_syscall'] + 
namespaces = ['oqs', 'pqclean', 'keccak', 'pqcrystals', 'pqmayo', 'init', 'fini', 'seedexpander', '__x86.get_pc_thunk', 'libjade', 'jade', '__jade', '__jasmin_syscall', 'pqcp'] non_namespaced = [] for symbolstr in symbols: