diff --git a/src/lib/crypto/kimchi_backend/common/dune b/src/lib/crypto/kimchi_backend/common/dune index 3d75a75b7ca..9a7cb8b8c91 100644 --- a/src/lib/crypto/kimchi_backend/common/dune +++ b/src/lib/crypto/kimchi_backend/common/dune @@ -25,6 +25,7 @@ base.caml ppx_inline_test.config bignum.bigint + zarith base.base_internalhash_types ;; local libraries tuple_lib diff --git a/src/lib/crypto/kimchi_backend/common/field.ml b/src/lib/crypto/kimchi_backend/common/field.ml index 91d76aa71ff..12261d7b9f5 100644 --- a/src/lib/crypto/kimchi_backend/common/field.ml +++ b/src/lib/crypto/kimchi_backend/common/field.ml @@ -37,6 +37,8 @@ module type Input_intf = sig val is_square : t -> bool + val compare : t -> t -> int + val equal : t -> t -> bool val print : t -> unit @@ -177,24 +179,22 @@ module Make (F : Input_intf) : let of_sexpable = of_bigint end) - let to_bignum_bigint n = - let rec go i two_to_the_i acc = - if Int.equal i size_in_bits then acc + let to_bignum_bigint = + let zero = of_int 0 in + let one = of_int 1 in + fun n -> + if equal n zero then Bignum_bigint.zero + else if equal n one then Bignum_bigint.one else - let acc' = - if Bigint.test_bit n i then Bignum_bigint.(acc + two_to_the_i) - else acc - in - go (i + 1) Bignum_bigint.(two_to_the_i + two_to_the_i) acc' - in - go 0 Bignum_bigint.one Bignum_bigint.zero + Bytes.unsafe_to_string + ~no_mutation_while_string_reachable:(to_bytes n) + |> Z.of_bits |> Bignum_bigint.of_zarith_bigint - let hash_fold_t s x = - Bignum_bigint.hash_fold_t s (to_bignum_bigint (to_bigint x)) + let hash_fold_t s x = Bignum_bigint.hash_fold_t s (to_bignum_bigint x) let hash = Hash.of_fold hash_fold_t - let compare t1 t2 = Bigint.compare (to_bigint t1) (to_bigint t2) + let compare = compare let equal = equal diff --git a/src/lib/crypto/kimchi_backend/kimchi_backend.mli b/src/lib/crypto/kimchi_backend/kimchi_backend.mli index b2fb42a082b..eb9baab3a9d 100644 --- a/src/lib/crypto/kimchi_backend/kimchi_backend.mli +++ b/src/lib/crypto/kimchi_backend/kimchi_backend.mli @@ -11,8 +11,6 @@ module Kimchi_backend_common : sig val sexp_of_t : t -> Sexplib0.Sexp.t - val compare : t -> t -> int - val bin_size_t : t Bin_prot.Size.sizer val bin_write_t : t Bin_prot.Write.writer @@ -56,6 +54,8 @@ module Kimchi_backend_common : sig val is_square : t -> bool + val compare : t -> t -> int + val equal : t -> t -> bool val print : t -> unit diff --git a/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fp.rs b/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fp.rs index 3d9144f1303..08264921088 100644 --- a/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fp.rs +++ b/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fp.rs @@ -31,7 +31,7 @@ impl CamlFp { unsafe extern "C" fn ocaml_compare(x: ocaml::Raw, y: ocaml::Raw) -> i32 { let x = x.as_pointer::(); let y = y.as_pointer::(); - match x.as_ref().0.cmp(&y.as_ref().0) { + match x.as_ref().0.into_repr().cmp(&y.as_ref().0.into_repr()) { core::cmp::Ordering::Less => -1, core::cmp::Ordering::Equal => 0, core::cmp::Ordering::Greater => 1, @@ -240,7 +240,7 @@ pub fn caml_pasta_fp_mut_square(mut x: ocaml::Pointer) { #[ocaml_gen::func] #[ocaml::func] pub fn caml_pasta_fp_compare(x: ocaml::Pointer, y: ocaml::Pointer) -> ocaml::Int { - match x.as_ref().0.cmp(&y.as_ref().0) { + match x.as_ref().0.into_repr().cmp(&y.as_ref().0.into_repr()) { Less => -1, Equal => 0, Greater => 1, diff --git a/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fq.rs b/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fq.rs index 8fbca9da595..bc81f5962a6 100644 --- a/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fq.rs +++ b/src/lib/crypto/kimchi_bindings/stubs/src/arkworks/pasta_fq.rs @@ -36,7 +36,7 @@ impl CamlFq { unsafe extern "C" fn ocaml_compare(x: ocaml::Raw, y: ocaml::Raw) -> i32 { let x = x.as_pointer::(); let y = y.as_pointer::(); - match x.as_ref().0.cmp(&y.as_ref().0) { + match x.as_ref().0.into_repr().cmp(&y.as_ref().0.into_repr()) { core::cmp::Ordering::Less => -1, core::cmp::Ordering::Equal => 0, core::cmp::Ordering::Greater => 1, @@ -241,7 +241,7 @@ pub fn caml_pasta_fq_mut_square(mut x: ocaml::Pointer) { #[ocaml_gen::func] #[ocaml::func] pub fn caml_pasta_fq_compare(x: ocaml::Pointer, y: ocaml::Pointer) -> ocaml::Int { - match x.as_ref().0.cmp(&y.as_ref().0) { + match x.as_ref().0.into_repr().cmp(&y.as_ref().0.into_repr()) { Less => -1, Equal => 0, Greater => 1,