Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize over field variable type #861

Merged
merged 14 commits into from
Jan 8, 2025
18 changes: 6 additions & 12 deletions src/base/as_prover0.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ module Make (Backend : sig
module Field : sig
type t
end

module Cvar : sig
type t
end
end)
(Types : Types.Types
with type field = Backend.Field.t
and type field_var = Backend.Field.t Cvar.t
and type 'a As_prover.t =
(Backend.Field.t Cvar.t -> Backend.Field.t) -> 'a) =
and type field_var = Backend.Cvar.t
and type 'a As_prover.t = (Backend.Cvar.t -> Backend.Field.t) -> 'a) =
struct
type 'a t = 'a Types.As_prover.t

Expand Down Expand Up @@ -80,12 +83,3 @@ struct
fun _ -> Option.value_exn t.value
end
end

module Make_extended (Env : sig
type field
end)
(As_prover : As_prover_intf.Basic with type field := Env.field) =
struct
include Env
include As_prover
end
9 changes: 4 additions & 5 deletions src/base/as_prover_intf.ml
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
module type Basic = sig
module Types : Types.Types

type field

type 'a t = 'a Types.As_prover.t

include Monad_let.S with type 'a t := 'a t

val run : 'a t -> (field Cvar.t -> field) -> 'a
val run : 'a t -> (Types.field_var -> Types.field) -> 'a

val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t

val read_var : field Cvar.t -> field t
val read_var : Types.field_var -> Types.field t

val read : ('var, 'value) Types.Typ.t -> 'var -> 'value t

module Provider : sig
type 'a t

val run : 'a t -> (field Cvar.t -> field) -> Request.Handler.t -> 'a option
val run :
'a t -> (Types.field_var -> Types.field) -> Request.Handler.t -> 'a option
end

module Handle : sig
Expand Down
58 changes: 6 additions & 52 deletions src/base/backend_extended.ml
Original file line number Diff line number Diff line change
Expand Up @@ -20,44 +20,10 @@ module type S = sig
val to_bignum_bigint : t -> Bignum_bigint.t
end

module Cvar : sig
type t = Field.t Cvar.t [@@deriving sexp]

val length : t -> int

module Unsafe : sig
val of_index : int -> t
end

val eval :
[ `Return_values_will_be_mutated of int -> Field.t ] -> t -> Field.t

val constant : Field.t -> t

val to_constant_and_terms : t -> Field.t option * (Field.t * int) list

val add : t -> t -> t

val negate : t -> t

val scale : t -> Field.t -> t

val sub : t -> t -> t

val linear_combination : (Field.t * t) list -> t

val sum : t list -> t

val ( + ) : t -> t -> t

val ( - ) : t -> t -> t

val ( * ) : Field.t -> t -> t

val var_indices : t -> int list

val to_constant : t -> Field.t option
end
module Cvar :
Backend_intf.Cvar_intf
with type field := Field.t
and type t = Field.t Cvar.t

module Constraint : sig
type t [@@deriving sexp]
Expand Down Expand Up @@ -91,6 +57,7 @@ module Make (Backend : Backend_intf.S) :
with type Field.t = Backend.Field.t
and type Field.Vector.t = Backend.Field.Vector.t
and type Bigint.t = Backend.Bigint.t
and type Cvar.t = Backend.Cvar.t
and type R1CS_constraint_system.t = Backend.R1CS_constraint_system.t
and type Run_state.t = Backend.Run_state.t
and type Constraint.t = Backend.Constraint.t = struct
Expand Down Expand Up @@ -195,20 +162,7 @@ module Make (Backend : Backend_intf.S) :
let ( / ) = div
end

module Cvar = struct
include Cvar.Make (Field)

let var_indices t =
let _, terms = to_constant_and_terms t in
List.map ~f:(fun (_, v) -> v) terms

let to_constant : t -> Field.t option = function
| Constant x ->
Some x
| _ ->
None
end

module Cvar = Cvar
module Constraint = Constraint
module R1CS_constraint_system = R1CS_constraint_system
module Run_state = Run_state
Expand Down
54 changes: 48 additions & 6 deletions src/base/backend_intf.ml
Original file line number Diff line number Diff line change
@@ -1,24 +1,66 @@
module type Cvar_intf = sig
type field

type t [@@deriving sexp]

val length : t -> int

module Unsafe : sig
val of_index : int -> t
end

val eval : [ `Return_values_will_be_mutated of int -> field ] -> t -> field

val constant : field -> t

val to_constant_and_terms : t -> field option * (field * int) list

val add : t -> t -> t

val negate : t -> t

val scale : t -> field -> t

val sub : t -> t -> t

val linear_combination : (field * t) list -> t

val sum : t list -> t

val ( + ) : t -> t -> t

val ( - ) : t -> t -> t

val ( * ) : field -> t -> t

val var_indices : t -> int list

val to_constant : t -> field option
end

module type S = sig
module Field : Snarky_intf.Field.S

module Bigint : Snarky_intf.Bigint_intf.Extended with type field := Field.t

val field_size : Bigint.t

module Cvar : Cvar_intf with type field := Field.t and type t = Field.t Cvar.t

module Constraint : sig
type t [@@deriving sexp]

val boolean : Field.t Cvar.t -> t
val boolean : Cvar.t -> t

val equal : Field.t Cvar.t -> Field.t Cvar.t -> t
val equal : Cvar.t -> Cvar.t -> t

val r1cs : Field.t Cvar.t -> Field.t Cvar.t -> Field.t Cvar.t -> t
val r1cs : Cvar.t -> Cvar.t -> Cvar.t -> t

val square : Field.t Cvar.t -> Field.t Cvar.t -> t
val square : Cvar.t -> Cvar.t -> t

val eval : t -> (Field.t Cvar.t -> Field.t) -> bool
val eval : t -> (Cvar.t -> Field.t) -> bool

val log_constraint : t -> (Field.t Cvar.t -> Field.t) -> string
val log_constraint : t -> (Cvar.t -> Field.t) -> string
end

module R1CS_constraint_system :
Expand Down
16 changes: 6 additions & 10 deletions src/base/checked.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@ open Core_kernel

module Make
(Backend : Backend_extended.S)
(Types : Types.Types)
(Types : Types.Types with type field_var = Backend.Cvar.t)
(Basic : Checked_intf.Basic
with type field = Backend.Field.t
and type constraint_ = Backend.Constraint.t
with type constraint_ = Backend.Constraint.t
with module Types := Types)
(As_prover : As_prover_intf.Basic
with type field := Basic.field
with module Types := Types) :
(As_prover : As_prover_intf.Basic with module Types := Types) :
Checked_intf.S
with module Types := Types
with type field = Backend.Field.t
and type run_state = Basic.run_state
with type run_state = Basic.run_state
and type constraint_ = Basic.constraint_ = struct
include Basic

Expand Down Expand Up @@ -81,8 +77,8 @@ module Make
bind acc ~f:(fun () -> add_constraint c) )

let assert_equal x y =
match (x, y) with
| Cvar.Constant x, Cvar.Constant y ->
match (Backend.Cvar.to_constant x, Backend.Cvar.to_constant y) with
| Some x, Some y ->
if Backend.Field.equal x y then return ()
else
failwithf
Expand Down
11 changes: 4 additions & 7 deletions src/base/checked_intf.ml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module type Basic = sig
module Types : Types.Types

type field

type constraint_

type 'a t = 'a Types.Checked.t
Expand Down Expand Up @@ -40,8 +38,6 @@ end
module type S = sig
module Types : Types.Types

type field

type constraint_

type run_state
Expand Down Expand Up @@ -95,13 +91,14 @@ module type S = sig

val assert_ : constraint_ -> unit t

val assert_r1cs : field Cvar.t -> field Cvar.t -> field Cvar.t -> unit t
val assert_r1cs :
Types.field_var -> Types.field_var -> Types.field_var -> unit t

val assert_square : field Cvar.t -> field Cvar.t -> unit t
val assert_square : Types.field_var -> Types.field_var -> unit t

val assert_all : constraint_ list -> unit t

val assert_equal : field Cvar.t -> field Cvar.t -> unit t
val assert_equal : Types.field_var -> Types.field_var -> unit t

val direct : (run_state -> run_state * 'a) -> 'a t

Expand Down
27 changes: 8 additions & 19 deletions src/base/checked_runner.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end
module Simple_types (Backend : Backend_extended.S) = Types.Make_types (struct
type field = Backend.Field.t

type field_var = field Cvar.t
type field_var = Backend.Cvar.t

type 'a checked = 'a T(Backend).t

Expand All @@ -26,16 +26,14 @@ module Make_checked
(Backend : Backend_extended.S)
(Types : Types.Types
with type field = Backend.Field.t
and type field_var = Backend.Field.t Cvar.t
and type field_var = Backend.Cvar.t
and type 'a Checked.t = 'a Simple_types(Backend).Checked.t
and type 'a As_prover.t = 'a Simple_types(Backend).As_prover.t
and type ('var, 'value, 'aux) Typ.typ' =
('var, 'value, 'aux) Simple_types(Backend).Typ.typ'
and type ('var, 'value) Typ.typ =
('var, 'value) Simple_types(Backend).Typ.typ)
(As_prover : As_prover_intf.Basic
with type field := Backend.Field.t
with module Types := Types) =
(As_prover : As_prover_intf.Basic with module Types := Types) =
struct
type run_state = Backend.Run_state.t

Expand Down Expand Up @@ -265,15 +263,11 @@ struct
end

module type Run_extras = sig
type field

type cvar

type run_state

module Types : Types.Types

val get_value : run_state -> cvar -> field
val get_value : run_state -> Types.field_var -> Types.field

val run_as_prover :
'a Types.As_prover.t option -> run_state -> run_state * 'a option
Expand Down Expand Up @@ -318,15 +312,10 @@ struct
include
Checked_intf.Basic
with module Types := Types
with type field := Checked_runner.field
and type run_state := run_state
with type run_state := run_state

include
Run_extras
with module Types := Types
with type field := Backend.Field.t
and type cvar := Backend.Cvar.t
and type run_state := run_state
Run_extras with module Types := Types with type run_state := run_state
end )

let run = Checked_runner.eval
Expand Down Expand Up @@ -392,9 +381,9 @@ module type S = sig
module State : sig
val make :
num_inputs:int
-> input:field Run_state_intf.Vector.t
-> input:Types.field Run_state_intf.Vector.t
-> next_auxiliary:int ref
-> aux:field Run_state_intf.Vector.t
-> aux:Types.field Run_state_intf.Vector.t
-> ?system:r1cs
-> ?eval_constraints:bool
-> ?handler:Request.Handler.t
Expand Down
10 changes: 10 additions & 0 deletions src/base/cvar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,14 @@ module Make (Field : Snarky_intf.Field.Extended) = struct
(List.filter_map (Map.to_alist map) ~f:(fun (i, f) ->
if Field.(equal f zero) then None
else Some (Int.to_string i, `String (Field.to_string f)) ) )

let var_indices t =
let _, terms = to_constant_and_terms t in
List.map ~f:(fun (_, v) -> v) terms

let to_constant : t -> Field.t option = function
| Constant x ->
Some x
| _ ->
None
end
11 changes: 3 additions & 8 deletions src/base/runners.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,12 @@ module Make
and type field_var = Backend.Cvar.t)
(Checked : Checked_intf.Extended
with module Types := Types
with type field = Backend.Field.t
and type run_state = Backend.Run_state.t
with type run_state = Backend.Run_state.t
and type constraint_ = Backend.Constraint.t)
(As_prover : As_prover_intf.Basic
with type field := Backend.Field.t
with module Types := Types)
(As_prover : As_prover_intf.Basic with module Types := Types)
(Runner : Checked_runner.S
with module Types := Types
with type field := Backend.Field.t
and type cvar := Backend.Cvar.t
and type constr := Backend.Constraint.t option
with type constr := Backend.Constraint.t option
and type r1cs := Backend.R1CS_constraint_system.t
and type run_state = Backend.Run_state.t) =
struct
Expand Down
Loading