Skip to content

Commit

Permalink
editing...
Browse files Browse the repository at this point in the history
  • Loading branch information
jzstark committed Aug 31, 2024
1 parent a839af0 commit af695f9
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 137 deletions.
39 changes: 18 additions & 21 deletions src/base/algodiff/owl_algodiff_core_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,46 +13,46 @@ module type Sig = sig
(** {5 Core functions} *)

val tag : unit -> int
(** TODO *)
(** start global tagging counter *)

val primal : t -> t
(** TODO *)
(** get primal component of DF or DR type *)

val primal' : t -> t
(** TODO *)
(** iteratively get primal component of DF or DR type until the component itself is not DF/DR *)

val zero : t -> t
(** TODO *)
(** return a zero value, which type decided by the input value *)

val reset_zero : t -> t
(** TODO *)
(** [reset_zero x] iteratively resets all elements included in [x] *)

val tangent : t -> t
(** TODO *)
(** get the tangent component of input, if the data type is suitable *)

val adjref : t -> t ref
(** TODO *)
(** get the adjref component of input, if the data type is suitable *)

val adjval : t -> t
(** TODO *)
(** get the adjval component of input, if the data type is suitableTODO *)

val shape : t -> int array
(** TODO *)
(** get the shape of primal' value of input *)

val is_float : t -> bool
(** TODO *)
(** check if input is of float value; if input is of type DF/DR, check its primal' value *)

val is_arr : t -> bool
(** TODO *)
(** check if input is of ndarray value; if input is of type DF/DR, check its primal' value *)

val row_num : t -> int
(** number of rows *)
(** get the shape of primal' value of input; and then get the first dimension *)

val col_num : t -> int
(** number of columns *)
(** get the shape of primal' value of input; and then get the second dimension *)

val numel : t -> int
(** number of elements *)
(** for ndarray type input, return its total number of elements. *)

val clip_by_value : amin:A.elt -> amax:A.elt -> t -> t
(** other functions, without tracking gradient *)
Expand All @@ -61,13 +61,13 @@ module type Sig = sig
(** other functions, without tracking gradient *)

val copy_primal' : t -> t
(** TODO *)
(** if primal' value of input is ndarray, copy its value in a new AD type ndarray *)

val tile : t -> int array -> t
(** TODO *)
(** if primal' value of input is ndarray, apply the tile function *)

val repeat : t -> int array -> t
(** TODO *)
(** if primal' value of input is ndarray, apply the repeat function *)

val pack_elt : A.elt -> t
(** convert from [elt] type to [t] type. *)
Expand All @@ -93,14 +93,11 @@ module type Sig = sig
(* functions to report errors, help in debugging *)

val deep_info : t -> string
(** TODO *)

val type_info : t -> string
(** TODO *)

val error_binop : string -> t -> t -> 'a
(** TODO *)

val error_uniop : string -> t -> 'a
(** TODO *)

end
34 changes: 14 additions & 20 deletions src/base/compute/owl_computation_graph_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,26 @@ module type Sig = sig
(** {5 Type definition} *)

type graph
(** TODO *)

(** {5 Core functions} *)

val shape_or_value : t -> string
(** TODO *)
(** print shape for ndarrays, whilst value for scalars *)

val graph_to_dot : graph -> string
(** TODO *)
(** generate a string that can be written to a .dot file to draw the graph *)

val graph_to_trace : graph -> string
(** TODO *)
(** print the graph structure to a string *)

val save_graph : 'a -> string -> unit
(** TODO *)
(** save the graph object to a file with given name, using marshall format *)

val load_graph : string -> 'a * 'b
(** TODO *)
(** load the graph object from a file with given name *)

val collect_rvs : attr Owl_graph.node array -> attr Owl_graph.node array
(** TODO *)
(** traverse each node in the input array, and return the random variable type nodes. *)

val invalidate_rvs : graph -> unit
(** TODO *)
Expand All @@ -42,48 +41,43 @@ module type Sig = sig
-> output:attr Owl_graph.node array
-> string
-> graph
(** TODO *)
(** Build a graph based on input nodes, output nodes, and graph name *)

val get_inputs : graph -> attr Owl_graph.node array
(** TODO *)
(** get input nodes of a graph *)

val get_outputs : graph -> attr Owl_graph.node array
(** TODO *)
(** get output nodes of a graph *)

val get_node_arr_val : attr Owl_graph.node -> A.arr
(** TODO *)

val get_node_elt_val : attr Owl_graph.node -> A.elt
(** TODO *)

val set_node_arr_val : attr Owl_graph.node -> value -> unit
(** TODO *)

val set_node_elt_val : attr Owl_graph.node -> value -> unit
(** TODO *)

val is_iopair_safe : 'a Owl_graph.node -> 'a Owl_graph.node -> bool
(** TODO *)

val make_iopair
: graph
-> attr Owl_graph.node array
-> attr Owl_graph.node array
-> unit
(** TODO *)
(** create an iopair between the input nodes and output nodes in a graph *)

val update_iopair : graph -> unit
(** TODO *)

val remove_unused_iopair
: 'a Owl_graph.node array
-> 'b array
-> 'a Owl_graph.node array * 'b array
(** TODO *)
(** remove unuserd iopair from an array of nodes *)

val init_inputs : (attr Owl_graph.node -> value) -> graph -> unit
(** TODO *)
(** initialize inputs nodes of a graph with given function [f] *)

val optimise : graph -> unit
(** TODO *)
(** optimise the graph structure *)

end
94 changes: 47 additions & 47 deletions src/base/compute/owl_computation_symbol_sig.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,43 @@ module type Sig = sig
(** {5 Core functions} *)

val op_to_str : op -> string
(** TODO *)
(** return the name of the operator as string *)

val is_random_variable : op -> bool
(** TODO *)
(** check if operator is randon variable *)

val refnum : 'a Owl_graph.node -> int
(** TODO *)
(** return the reference number of the given node *)

val node_shape : attr Owl_graph.node -> int array
(** TODO *)
(** return the shape of a node *)

val node_numel : attr Owl_graph.node -> int
(** TODO *)
(** return the number of elements of a node *)

val is_shape_unknown : attr Owl_graph.node -> bool
(** TODO *)
(** check if the shape of the input node is unknown *)

val infer_shape_graph : attr Owl_graph.node array -> unit
(** TODO *)
(** automatically infer the shape of input node according to its descendents' shapes *)

val shape_to_str : int array option array -> string
(** TODO *)
(** helper function; return the input array in string format. *)

val node_to_str : attr Owl_graph.node -> string
(** TODO *)
(** print node's information to string *)

val node_to_arr : t -> arr
(** TODO *)
(** Wrap computation graph node in an array type *)

val arr_to_node : arr -> t
(** TODO *)
(** Unwrap the array type to get the computation graph node within *)

val node_to_elt : t -> elt
(** TODO *)
(** Wrap computation graph node in an Elt type *)

val elt_to_node : elt -> t
(** TODO *)
(** Unwrap the Elt type to get the computation graph node within *)

val make_node
: ?name:string
Expand All @@ -62,26 +62,26 @@ module type Sig = sig
-> ?state:state
-> op
-> attr Owl_graph.node
(** TODO *)
(** crate a computation graph node *)

val make_then_connect
: ?shape:int array option array
-> op
-> attr Owl_graph.node array
-> attr Owl_graph.node
(** TODO *)
(** make nodes and then connect parents and children *)

val var_arr : ?shape:int array -> string -> arr
(** TODO *)
(** creat a node and wrap in Arr type *)

val var_elt : string -> elt
(** TODO *)
(** creat a node and wrap in Elt type *)

val const_arr : string -> A.arr -> arr
(** TODO *)
(** get ndarray value from input and create an node and wrap in Arr type *)

val const_elt : string -> A.elt -> elt
(** TODO *)
(** get value from input and create an node and wrap in Elt type *)

val new_block_id : unit -> int
(** [new_block_id ()] returns an unused block id. *)
Expand Down Expand Up @@ -120,25 +120,25 @@ module type Sig = sig
*)

val set_value : attr Owl_graph.node -> value array -> unit
(** TODO *)
(** set the arrays of value to cgraph node *)

val get_value : attr Owl_graph.node -> value array
(** TODO *)
(** get the arrays of value of cgraph node *)

val set_operator : attr Owl_graph.node -> op -> unit
(** TODO *)
(** set the operator of cgraph node *)

val get_operator : attr Owl_graph.node -> op
(** TODO *)
(** get the operator of cgraph node *)

val set_reuse : attr Owl_graph.node -> bool -> unit
(** TODO *)
(** set reuse attribute in a node *)

val get_reuse : attr Owl_graph.node -> bool
(** TODO *)
(** get reuse attribute in a node *)

val is_shared : attr Owl_graph.node -> bool
(** TODO *)
(** check of the data block of memory is shared in a node *)

val get_shared_nodes : attr Owl_graph.node -> attr Owl_graph.node array
(**
Expand All @@ -147,16 +147,16 @@ module type Sig = sig
*)

val is_var : attr Owl_graph.node -> bool
(** TODO *)
(** check if the node's operator is Var type *)

val is_const : attr Owl_graph.node -> bool
(** TODO *)
(** check if the node's operator is Const type *)

val is_node_arr : attr Owl_graph.node -> bool
(** TODO *)
(** check the shape of a node's attr and return if it indicates an ndarray *)

val is_node_elt : attr Owl_graph.node -> bool
(** TODO *)
(** check the shape of a node's attr and return if it indicates an elt *)

val is_assigned : attr Owl_graph.node -> bool
(**
Expand All @@ -171,53 +171,53 @@ module type Sig = sig
*)

val is_valid : attr Owl_graph.node -> bool
(** TODO *)
(** check if the state attribute of a node is Valid *)

val validate : attr Owl_graph.node -> unit
(** TODO *)
(** set Valid to the state attribute of a node *)

val invalidate : attr Owl_graph.node -> unit
(** TODO *)
(** set Invalid to the state attribute of a node *)

val invalidate_graph : attr Owl_graph.node -> unit
(** TODO *)
(** iteratively invalidate the nodes in a graph *)

val is_freeze : attr Owl_graph.node -> bool
(** TODO *)
(** check the freeze attribute of a node *)

val freeze : attr Owl_graph.node -> unit
(** TODO *)
(** return the freeze attribute of a node *)

val freeze_descendants : attr Owl_graph.node array -> unit
(** TODO *)
(** iteratively freeze the descendants of a node *)

val freeze_ancestors : attr Owl_graph.node array -> unit
(** TODO *)
(** iteratively freeze the ancestors of a node *)

val pack_arr : A.arr -> arr
(** TODO *)
(** pack an A.arr type input into Arr type *)

val unpack_arr : arr -> A.arr
(** TODO *)
(** unpack input into A.arr type *)

val pack_elt : A.elt -> elt
(** TODO *)
(** pack an A.elt type input into Elt type *)

val unpack_elt : elt -> A.elt
(** TODO *)
(** unpack input into A.elt type *)

val unsafe_assign_arr : arr -> A.arr -> unit
(** TODO *)
(** assign Arr type value *)

val assign_arr : arr -> A.arr -> unit
(** TODO *)
(** assign Arr type value *)

val assign_elt : elt -> A.elt -> unit
(** TODO *)
(** assign Elt type value *)

val float_to_elt : float -> elt
(** TODO *)
(** build an Elt type based on float value *)

val elt_to_float : elt -> float
(** TODO *)
(** retrive a float value from an Elt type value *)
end
Loading

0 comments on commit af695f9

Please sign in to comment.