Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize over field variable type #861

Merged
merged 14 commits into from
Jan 8, 2025
53 changes: 25 additions & 28 deletions snarky_integer/integer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ module Interval = struct

let iter t ~f = match t with Constant x -> f x | Less_than x -> f x

let check (type f) ~m:((module M) : f m) t =
let check (type f v) ~m:((module M) : (f, v) m) t =
iter t ~f:(fun x -> assert (B.(x < M.Field.size))) ;
t

Expand Down Expand Up @@ -113,18 +113,15 @@ module Interval = struct
end

(* TODO: Use <= instead of < for the upper bound *)
type 'f t =
{ value : 'f Cvar.t
; interval : Interval.t
; mutable bits : 'f Cvar.t Boolean.t list option
}
type ('f, 'v) t =
{ value : 'v; interval : Interval.t; mutable bits : 'v Boolean.t list option }

let create ~value ~upper_bound =
{ value; interval = Less_than upper_bound; bits = None }

let to_field t = t.value

let constant (type f) ?length ~m:((module M) as m : f m) x =
let constant (type f v) ?length ~m:((module M) as m : (f, v) m) x =
let open M in
assert (B.( < ) x Field.size) ;
let upper_bound = B.(one + x) in
Expand All @@ -145,7 +142,7 @@ let constant (type f) ?length ~m:((module M) as m : f m) x =
constant Boolean.typ B.(shift_right x i land one = one) ) )
}

let shift_left (type f) ~m:((module M) as m : f m) t k =
let shift_left (type f v) ~m:((module M) as m : (f, v) m) t k =
let open M in
let two_to_k = B.(one lsl k) in
{ value = Field.(constant (bigint_to_field ~m two_to_k) * t.value)
Expand All @@ -155,7 +152,7 @@ let shift_left (type f) ~m:((module M) as m : f m) t k =
List.init k ~f:(fun _ -> Boolean.false_) @ bs )
}

let of_bits (type f) ~m:((module M) : f m) bs =
let of_bits (type f v) ~m:((module M) : (f, v) m) bs =
let bs = Bitstring.Lsb_first.to_list bs in
{ value = M.Field.project bs
; interval = Less_than B.(one lsl List.length bs)
Expand All @@ -167,7 +164,7 @@ let of_bits (type f) ~m:((module M) : f m) bs =
a = q * b + r
r < b
*)
let div_mod (type f) ~m:((module M) as m : f m) a b =
let div_mod (type f v) ~m:((module M) as m : (f, v) m) a b =
let open M in
(* Guess (q, r) *)
let q, r =
Expand Down Expand Up @@ -198,7 +195,7 @@ let div_mod (type f) ~m:((module M) as m : f m) a b =
}
, { value = r; interval = b.interval; bits = Some r_bits } )

let subtract_unpacking (type f) ~m:((module M) : f m) a b =
let subtract_unpacking (type f v) ~m:((module M) : (f, v) m) a b =
M.with_label "Integer.subtract_unpacking" (fun () ->
assert (Interval.gte a.interval b.interval) ;
let value = M.Field.(sub a.value b.value) in
Expand All @@ -207,15 +204,15 @@ let subtract_unpacking (type f) ~m:((module M) : f m) a b =
let bits = M.Field.unpack value ~length in
{ value; interval = a.interval; bits = Some bits } )

let add (type f) ~m:((module M) as m : f m) a b =
let add (type f v) ~m:((module M) as m : (f, v) m) a b =
let interval = Interval.(add ~m a.interval b.interval) in
{ value = M.Field.(a.value + b.value); interval; bits = None }

let mul (type f) ~m:((module M) as m : f m) a b =
let mul (type f v) ~m:((module M) as m : (f, v) m) a b =
let interval = Interval.(mul ~m a.interval b.interval) in
{ value = M.Field.(a.value * b.value); interval; bits = None }

let to_bits ?length (type f) ~m:((module M) : f m) t =
let to_bits ?length (type f v) ~m:((module M) : (f, v) m) t =
match t.bits with
| Some bs -> (
let bs = Bitstring.Lsb_first.of_list bs in
Expand All @@ -239,7 +236,7 @@ let to_bits_exn t = Bitstring.Lsb_first.of_list (Option.value_exn t.bits)

let to_bits_opt t = Option.map ~f:Bitstring.Lsb_first.of_list t.bits

let min (type f) ~m:((module M) : f m) (a : f t) (b : f t) =
let min (type f v) ~m:((module M) : (f, v) m) (a : (f, v) t) (b : (f, v) t) =
let open M in
let bit_length =
Int.max (Interval.bits_needed a.interval) (Interval.bits_needed b.interval)
Expand All @@ -250,48 +247,48 @@ let min (type f) ~m:((module M) : f m) (a : f t) (b : f t) =
; bits = None
}

let if_ (type f) ~m:((module M) : f m) cond ~then_ ~else_ =
let if_ (type f v) ~m:((module M) : (f, v) m) cond ~then_ ~else_ =
{ value = M.Field.if_ cond ~then_:then_.value ~else_:else_.value
; interval = Interval.lub then_.interval else_.interval
; bits = None
}

let succ_if (type f) ~m:((module M) as m : f m) t (cond : f Cvar.t Boolean.t) =
let succ_if (type f v) ~m:((module M) as m : (f, v) m) t (cond : v Boolean.t) =
let open M in
{ value = Field.(add (cond :> t) t.value)
; interval = Interval.(lub t.interval (succ ~m t.interval))
; bits = None
}

let succ (type f) ~m:((module M) as m : f m) t =
let succ (type f v) ~m:((module M) as m : (f, v) m) t =
let open M in
{ value = Field.(add one t.value)
; interval = Interval.succ ~m t.interval
; bits = None
}

let equal (type f) ~m:((module M) : f m) a b = M.Field.equal a.value b.value
let equal (type f v) ~m:((module M) : (f, v) m) a b =
M.Field.equal a.value b.value

let max_bits a b =
Int.max (Interval.bits_needed a.interval) (Interval.bits_needed b.interval)

let lt (type f) ~m:((module M) : f m) a b =
let lt (type f v) ~m:((module M) : (f, v) m) a b =
(M.Field.compare ~bit_length:(max_bits a b) a.value b.value).less

let lte (type f) ~m:((module M) : f m) a b =
let lte (type f v) ~m:((module M) : (f, v) m) a b =
(M.Field.compare ~bit_length:(max_bits a b) a.value b.value).less_or_equal

let gte (type f) ~m:((module M) as m : f m) a b = M.Boolean.not (lt ~m a b)
let gte (type f v) ~m:((module M) as m : (f, v) m) a b =
M.Boolean.not (lt ~m a b)

let gt (type f) ~m:((module M) as m : f m) a b = M.Boolean.not (lte ~m a b)
let gt (type f v) ~m:((module M) as m : (f, v) m) a b =
M.Boolean.not (lte ~m a b)

let subtract_unpacking_or_zero (type f) ~m:((module M) as m : f m) a b =
let subtract_unpacking_or_zero (type f v) ~m:((module M) as m : (f, v) m) a b =
let flag = lt ~m a b in
( `Underflow flag
, { value =
M.Field.mul
(M.Field.sub a.value b.value)
(M.Boolean.not flag :> f Cvar.t)
, { value = M.Field.mul (M.Field.sub a.value b.value) (M.Boolean.not flag :> v)
; interval = a.interval
; bits = None
} )
63 changes: 36 additions & 27 deletions snarky_integer/integer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,28 @@ module Interval : sig
type t = Constant of B.t | Less_than of B.t
end

type 'f t =
{ value : 'f Cvar.t
; interval : Interval.t
; mutable bits : 'f Cvar.t Boolean.t list option
}
type ('f, 'v) t =
{ value : 'v; interval : Interval.t; mutable bits : 'v Boolean.t list option }

(** Create an value representing the given constant value.

The bit representation of the constant is cached, and is padded to [length]
when given.
*)
val constant : ?length:int -> m:'f m -> Bigint.t -> 'f t
val constant : ?length:int -> m:('f, 'v) m -> Bigint.t -> ('f, 'v) t

(** [shift_left ~m x k] is equivalent to multiplying [x] by [2^k].

The result has a cached bit representation whenever the given [x] had a
cached bit representation.
*)
val shift_left : m:'f m -> 'f t -> int -> 'f t
val shift_left : m:('f, 'v) m -> ('f, 'v) t -> int -> ('f, 'v) t

(** Create a value from the given bit string.

The given bit representation is cached.
*)
val of_bits : m:'f m -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t -> 'f t
val of_bits : m:('f, 'v) m -> 'v Boolean.t Bitstring.Lsb_first.t -> ('f, 'v) t

(** Compute the bit representation of the given integer.

Expand All @@ -51,72 +48,81 @@ val of_bits : m:'f m -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t -> 'f t
value is updated to include the cache.
*)
val to_bits :
?length:int -> m:'f m -> 'f t -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t
?length:int
-> m:('f, 'v) m
-> ('f, 'v) t
-> 'v Boolean.t Bitstring.Lsb_first.t

(** Return the cached bit representation, or raise an exception if the bit
representation has not been cached.
*)
val to_bits_exn : 'f t -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t
val to_bits_exn : ('f, 'v) t -> 'v Boolean.t Bitstring.Lsb_first.t

(** Returns [Some bs] for [bs] the cached bit representation, or [None] if the
bit representation has not been cached.
*)
val to_bits_opt : 'f t -> 'f Cvar.t Boolean.t Bitstring.Lsb_first.t option
val to_bits_opt : ('f, 'v) t -> 'v Boolean.t Bitstring.Lsb_first.t option

(** [div_mod ~m a b = (q, r)] such that [a = q * b + r] and [r < b].

The bit representations of [q] and [r] are calculated and cached.

NOTE: This uses approximately [log2(a) + 2 * log2(b)] constraints.
*)
val div_mod : m:'f m -> 'f t -> 'f t -> 'f t * 'f t
val div_mod :
m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t * ('f, 'v) t

val to_field : 'f t -> 'f Cvar.t
val to_field : ('f, 'v) t -> 'v

val create : value:'f Cvar.t -> upper_bound:Bigint.t -> 'f t
val create : value:'v -> upper_bound:Bigint.t -> ('f, 'v) t

(** [min ~m x y] returns a value equal the lesser of [x] and [y].

The result does not carry a cached bit representation.
*)
val min : m:'f m -> 'f t -> 'f t -> 'f t
val min : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

val if_ : m:'f m -> 'f Cvar.t Boolean.t -> then_:'f t -> else_:'f t -> 'f t
val if_ :
m:('f, 'v) m
-> 'v Boolean.t
-> then_:('f, 'v) t
-> else_:('f, 'v) t
-> ('f, 'v) t

(** [succ ~m x] computes the successor [x+1] of [x].

The result does not carry a cached bit representation.
*)
val succ : m:'f m -> 'f t -> 'f t
val succ : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t

(** [succ_if ~m x b] computes the integer [x+1] if [b] is [true], or [x]
otherwise.

The result does not carry a cached bit representation.
*)
val succ_if : m:'f m -> 'f t -> 'f Cvar.t Boolean.t -> 'f t
val succ_if : m:('f, 'v) m -> ('f, 'v) t -> 'v Boolean.t -> ('f, 'v) t

val equal : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val equal : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val lt : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val lt : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val lte : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val lte : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val gt : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val gt : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

val gte : m:'f m -> 'f t -> 'f t -> 'f Cvar.t Boolean.t
val gte : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> 'v Boolean.t

(** [add ~m x y] computes [x + y].

The result does not carry a cached bit representation.
*)
val add : m:'f m -> 'f t -> 'f t -> 'f t
val add : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

(** [mul ~m x y] computes [x * y].

The result does not carry a cached bit representation.
*)
val mul : m:'f m -> 'f t -> 'f t -> 'f t
val mul : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

(** [subtract_unpacking ~m x y] computes [x - y].

Expand All @@ -125,7 +131,7 @@ val mul : m:'f m -> 'f t -> 'f t -> 'f t

NOTE: This uses approximately [log2(x)] constraints.
*)
val subtract_unpacking : m:'f m -> 'f t -> 'f t -> 'f t
val subtract_unpacking : m:('f, 'v) m -> ('f, 'v) t -> ('f, 'v) t -> ('f, 'v) t

(** [subtract_unpacking_or_zero ~m x y] computes [x - y].

Expand All @@ -139,4 +145,7 @@ val subtract_unpacking : m:'f m -> 'f t -> 'f t -> 'f t
NOTE: This uses approximately [log2(x)] constraints.
*)
val subtract_unpacking_or_zero :
m:'f m -> 'f t -> 'f t -> [ `Underflow of 'f Cvar.t Boolean.t ] * 'f t
m:('f, 'v) m
-> ('f, 'v) t
-> ('f, 'v) t
-> [ `Underflow of 'v Boolean.t ] * ('f, 'v) t
4 changes: 2 additions & 2 deletions snarky_integer/util.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ open Snarky_backendless
open Snark
module B = Bigint

let bigint_to_field (type f) ~m:((module M) : f m) =
let bigint_to_field (type f v) ~m:((module M) : (f, v) m) =
let open M in
Fn.compose Bigint.to_field Bigint.of_bignum_bigint

let bigint_of_field (type f) ~m:((module M) : f m) =
let bigint_of_field (type f v) ~m:((module M) : (f, v) m) =
let open M in
Fn.compose Bigint.to_bignum_bigint Bigint.of_field
18 changes: 6 additions & 12 deletions src/base/as_prover0.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ module Make (Backend : sig
module Field : sig
type t
end

module Cvar : sig
type t
end
end)
(Types : Types.Types
with type field = Backend.Field.t
and type field_var = Backend.Field.t Cvar.t
and type 'a As_prover.t =
(Backend.Field.t Cvar.t -> Backend.Field.t) -> 'a) =
and type field_var = Backend.Cvar.t
and type 'a As_prover.t = (Backend.Cvar.t -> Backend.Field.t) -> 'a) =
struct
type 'a t = 'a Types.As_prover.t

Expand Down Expand Up @@ -80,12 +83,3 @@ struct
fun _ -> Option.value_exn t.value
end
end

module Make_extended (Env : sig
type field
end)
(As_prover : As_prover_intf.Basic with type field := Env.field) =
struct
include Env
include As_prover
end
9 changes: 4 additions & 5 deletions src/base/as_prover_intf.ml
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
module type Basic = sig
module Types : Types.Types

type field

type 'a t = 'a Types.As_prover.t

include Monad_let.S with type 'a t := 'a t

val run : 'a t -> (field Cvar.t -> field) -> 'a
val run : 'a t -> (Types.field_var -> Types.field) -> 'a

val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t

val read_var : field Cvar.t -> field t
val read_var : Types.field_var -> Types.field t

val read : ('var, 'value) Types.Typ.t -> 'var -> 'value t

module Provider : sig
type 'a t

val run : 'a t -> (field Cvar.t -> field) -> Request.Handler.t -> 'a option
val run :
'a t -> (Types.field_var -> Types.field) -> Request.Handler.t -> 'a option
end

module Handle : sig
Expand Down
Loading
Loading