diff --git a/src/base/algodiff/owl_algodiff_core_sig.ml b/src/base/algodiff/owl_algodiff_core_sig.ml index 9e039b32d..8e1a87256 100644 --- a/src/base/algodiff/owl_algodiff_core_sig.ml +++ b/src/base/algodiff/owl_algodiff_core_sig.ml @@ -13,46 +13,46 @@ module type Sig = sig (** {5 Core functions} *) val tag : unit -> int - (** TODO *) + (** start global tagging counter *) val primal : t -> t - (** TODO *) + (** get primal component of DF or DR type *) val primal' : t -> t - (** TODO *) + (** iteratively get primal component of DF or DR type until the component itself is not DF/DR *) val zero : t -> t - (** TODO *) + (** return a zero value, which type decided by the input value *) val reset_zero : t -> t - (** TODO *) + (** [reset_zero x] iteratively resets all elements included in [x] *) val tangent : t -> t - (** TODO *) + (** get the tangent component of input, if the data type is suitable *) val adjref : t -> t ref - (** TODO *) + (** get the adjref component of input, if the data type is suitable *) val adjval : t -> t - (** TODO *) + (** get the adjval component of input, if the data type is suitableTODO *) val shape : t -> int array - (** TODO *) + (** get the shape of primal' value of input *) val is_float : t -> bool - (** TODO *) + (** check if input is of float value; if input is of type DF/DR, check its primal' value *) val is_arr : t -> bool - (** TODO *) + (** check if input is of ndarray value; if input is of type DF/DR, check its primal' value *) val row_num : t -> int - (** number of rows *) + (** get the shape of primal' value of input; and then get the first dimension *) val col_num : t -> int - (** number of columns *) + (** get the shape of primal' value of input; and then get the second dimension *) val numel : t -> int - (** number of elements *) + (** for ndarray type input, return its total number of elements. *) val clip_by_value : amin:A.elt -> amax:A.elt -> t -> t (** other functions, without tracking gradient *) @@ -61,13 +61,13 @@ module type Sig = sig (** other functions, without tracking gradient *) val copy_primal' : t -> t - (** TODO *) + (** if primal' value of input is ndarray, copy its value in a new AD type ndarray *) val tile : t -> int array -> t - (** TODO *) + (** if primal' value of input is ndarray, apply the tile function *) val repeat : t -> int array -> t - (** TODO *) + (** if primal' value of input is ndarray, apply the repeat function *) val pack_elt : A.elt -> t (** convert from [elt] type to [t] type. *) @@ -93,14 +93,11 @@ module type Sig = sig (* functions to report errors, help in debugging *) val deep_info : t -> string - (** TODO *) val type_info : t -> string - (** TODO *) val error_binop : string -> t -> t -> 'a - (** TODO *) val error_uniop : string -> t -> 'a - (** TODO *) + end diff --git a/src/base/compute/owl_computation_graph_sig.ml b/src/base/compute/owl_computation_graph_sig.ml index fec07641b..b29366ae3 100644 --- a/src/base/compute/owl_computation_graph_sig.ml +++ b/src/base/compute/owl_computation_graph_sig.ml @@ -12,27 +12,26 @@ module type Sig = sig (** {5 Type definition} *) type graph - (** TODO *) (** {5 Core functions} *) val shape_or_value : t -> string - (** TODO *) + (** print shape for ndarrays, whilst value for scalars *) val graph_to_dot : graph -> string - (** TODO *) + (** generate a string that can be written to a .dot file to draw the graph *) val graph_to_trace : graph -> string - (** TODO *) + (** print the graph structure to a string *) val save_graph : 'a -> string -> unit - (** TODO *) + (** save the graph object to a file with given name, using marshall format *) val load_graph : string -> 'a * 'b - (** TODO *) + (** load the graph object from a file with given name *) val collect_rvs : attr Owl_graph.node array -> attr Owl_graph.node array - (** TODO *) + (** traverse each node in the input array, and return the random variable type nodes. *) val invalidate_rvs : graph -> unit (** TODO *) @@ -42,48 +41,43 @@ module type Sig = sig -> output:attr Owl_graph.node array -> string -> graph - (** TODO *) + (** Build a graph based on input nodes, output nodes, and graph name *) val get_inputs : graph -> attr Owl_graph.node array - (** TODO *) + (** get input nodes of a graph *) val get_outputs : graph -> attr Owl_graph.node array - (** TODO *) + (** get output nodes of a graph *) val get_node_arr_val : attr Owl_graph.node -> A.arr - (** TODO *) val get_node_elt_val : attr Owl_graph.node -> A.elt - (** TODO *) val set_node_arr_val : attr Owl_graph.node -> value -> unit - (** TODO *) val set_node_elt_val : attr Owl_graph.node -> value -> unit - (** TODO *) val is_iopair_safe : 'a Owl_graph.node -> 'a Owl_graph.node -> bool - (** TODO *) val make_iopair : graph -> attr Owl_graph.node array -> attr Owl_graph.node array -> unit - (** TODO *) + (** create an iopair between the input nodes and output nodes in a graph *) val update_iopair : graph -> unit - (** TODO *) val remove_unused_iopair : 'a Owl_graph.node array -> 'b array -> 'a Owl_graph.node array * 'b array - (** TODO *) + (** remove unuserd iopair from an array of nodes *) val init_inputs : (attr Owl_graph.node -> value) -> graph -> unit - (** TODO *) + (** initialize inputs nodes of a graph with given function [f] *) val optimise : graph -> unit - (** TODO *) + (** optimise the graph structure *) + end diff --git a/src/base/compute/owl_computation_symbol_sig.ml b/src/base/compute/owl_computation_symbol_sig.ml index 0b03a1d4b..4713189d6 100644 --- a/src/base/compute/owl_computation_symbol_sig.ml +++ b/src/base/compute/owl_computation_symbol_sig.ml @@ -15,43 +15,43 @@ module type Sig = sig (** {5 Core functions} *) val op_to_str : op -> string - (** TODO *) + (** return the name of the operator as string *) val is_random_variable : op -> bool - (** TODO *) + (** check if operator is randon variable *) val refnum : 'a Owl_graph.node -> int - (** TODO *) + (** return the reference number of the given node *) val node_shape : attr Owl_graph.node -> int array - (** TODO *) + (** return the shape of a node *) val node_numel : attr Owl_graph.node -> int - (** TODO *) + (** return the number of elements of a node *) val is_shape_unknown : attr Owl_graph.node -> bool - (** TODO *) + (** check if the shape of the input node is unknown *) val infer_shape_graph : attr Owl_graph.node array -> unit - (** TODO *) + (** automatically infer the shape of input node according to its descendents' shapes *) val shape_to_str : int array option array -> string - (** TODO *) + (** helper function; return the input array in string format. *) val node_to_str : attr Owl_graph.node -> string - (** TODO *) + (** print node's information to string *) val node_to_arr : t -> arr - (** TODO *) + (** Wrap computation graph node in an array type *) val arr_to_node : arr -> t - (** TODO *) + (** Unwrap the array type to get the computation graph node within *) val node_to_elt : t -> elt - (** TODO *) + (** Wrap computation graph node in an Elt type *) val elt_to_node : elt -> t - (** TODO *) + (** Unwrap the Elt type to get the computation graph node within *) val make_node : ?name:string @@ -62,26 +62,26 @@ module type Sig = sig -> ?state:state -> op -> attr Owl_graph.node - (** TODO *) + (** crate a computation graph node *) val make_then_connect : ?shape:int array option array -> op -> attr Owl_graph.node array -> attr Owl_graph.node - (** TODO *) + (** make nodes and then connect parents and children *) val var_arr : ?shape:int array -> string -> arr - (** TODO *) + (** creat a node and wrap in Arr type *) val var_elt : string -> elt - (** TODO *) + (** creat a node and wrap in Elt type *) val const_arr : string -> A.arr -> arr - (** TODO *) + (** get ndarray value from input and create an node and wrap in Arr type *) val const_elt : string -> A.elt -> elt - (** TODO *) + (** get value from input and create an node and wrap in Elt type *) val new_block_id : unit -> int (** [new_block_id ()] returns an unused block id. *) @@ -120,25 +120,25 @@ module type Sig = sig *) val set_value : attr Owl_graph.node -> value array -> unit - (** TODO *) + (** set the arrays of value to cgraph node *) val get_value : attr Owl_graph.node -> value array - (** TODO *) + (** get the arrays of value of cgraph node *) val set_operator : attr Owl_graph.node -> op -> unit - (** TODO *) + (** set the operator of cgraph node *) val get_operator : attr Owl_graph.node -> op - (** TODO *) + (** get the operator of cgraph node *) val set_reuse : attr Owl_graph.node -> bool -> unit - (** TODO *) + (** set reuse attribute in a node *) val get_reuse : attr Owl_graph.node -> bool - (** TODO *) + (** get reuse attribute in a node *) val is_shared : attr Owl_graph.node -> bool - (** TODO *) + (** check of the data block of memory is shared in a node *) val get_shared_nodes : attr Owl_graph.node -> attr Owl_graph.node array (** @@ -147,16 +147,16 @@ module type Sig = sig *) val is_var : attr Owl_graph.node -> bool - (** TODO *) + (** check if the node's operator is Var type *) val is_const : attr Owl_graph.node -> bool - (** TODO *) + (** check if the node's operator is Const type *) val is_node_arr : attr Owl_graph.node -> bool - (** TODO *) + (** check the shape of a node's attr and return if it indicates an ndarray *) val is_node_elt : attr Owl_graph.node -> bool - (** TODO *) + (** check the shape of a node's attr and return if it indicates an elt *) val is_assigned : attr Owl_graph.node -> bool (** @@ -171,53 +171,53 @@ module type Sig = sig *) val is_valid : attr Owl_graph.node -> bool - (** TODO *) + (** check if the state attribute of a node is Valid *) val validate : attr Owl_graph.node -> unit - (** TODO *) + (** set Valid to the state attribute of a node *) val invalidate : attr Owl_graph.node -> unit - (** TODO *) + (** set Invalid to the state attribute of a node *) val invalidate_graph : attr Owl_graph.node -> unit - (** TODO *) + (** iteratively invalidate the nodes in a graph *) val is_freeze : attr Owl_graph.node -> bool - (** TODO *) + (** check the freeze attribute of a node *) val freeze : attr Owl_graph.node -> unit - (** TODO *) + (** return the freeze attribute of a node *) val freeze_descendants : attr Owl_graph.node array -> unit - (** TODO *) + (** iteratively freeze the descendants of a node *) val freeze_ancestors : attr Owl_graph.node array -> unit - (** TODO *) + (** iteratively freeze the ancestors of a node *) val pack_arr : A.arr -> arr - (** TODO *) + (** pack an A.arr type input into Arr type *) val unpack_arr : arr -> A.arr - (** TODO *) + (** unpack input into A.arr type *) val pack_elt : A.elt -> elt - (** TODO *) + (** pack an A.elt type input into Elt type *) val unpack_elt : elt -> A.elt - (** TODO *) + (** unpack input into A.elt type *) val unsafe_assign_arr : arr -> A.arr -> unit - (** TODO *) + (** assign Arr type value *) val assign_arr : arr -> A.arr -> unit - (** TODO *) + (** assign Arr type value *) val assign_elt : elt -> A.elt -> unit - (** TODO *) + (** assign Elt type value *) val float_to_elt : float -> elt - (** TODO *) + (** build an Elt type based on float value *) val elt_to_float : elt -> float - (** TODO *) + (** retrive a float value from an Elt type value *) end diff --git a/src/base/core/owl_lazy.mli b/src/base/core/owl_lazy.mli index fecd366a2..f7a2d5530 100644 --- a/src/base/core/owl_lazy.mli +++ b/src/base/core/owl_lazy.mli @@ -15,97 +15,92 @@ module Make (A : Ndarray_Mutable) : sig (** {5 Type definition} *) type arr - (** TODO *) type elt - (** TODO *) type value - (** TODO *) type attr - (** TODO *) type graph - (** TODO *) (** {5 Type conversion functions} *) val arr_to_value : A.arr -> value - (** TODO *) + (** pack A.arr type to value type *) val value_to_arr : value -> A.arr - (** TODO *) + (** retrieve A.arr type from value type input *) val elt_to_value : A.elt -> value - (** TODO *) + (** pack A.elt type to value type *) val value_to_elt : value -> A.elt - (** TODO *) + (** retrieve A.elt type from value type input *) val value_to_float : value -> float - (** TODO *) + (** retrieve float type from value type input *) val node_to_arr : attr node -> arr - (** TODO *) + (** get Arr type from node *) val arr_to_node : arr -> attr node - (** TODO *) + (** pack Arr type into node *) val node_to_elt : attr node -> elt - (** TODO *) + (** get Elt type from node *) val elt_to_node : elt -> attr node - (** TODO *) + (** pack Elt type into node *) val pack_arr : A.arr -> arr - (** TODO *) + (** pack A.arr type into Arr type *) val unpack_arr : arr -> A.arr - (** TODO *) + (** unpack Arr type into A.arr *) val pack_elt : A.elt -> elt - (** TODO *) + (** pack A.elt type into Elt type *) val unpack_elt : elt -> A.elt - (** TODO *) + (** unpack Elt type into A.elt *) val float_to_elt : float -> elt - (** TODO *) + (** build Elt type to float *) val elt_to_float : elt -> float - (** TODO *) + (** get float value from Elt type *) (** {5 Utility functions} *) val graph_to_dot : graph -> string - (** TODO *) + (** convert graph to dot string, which can be saved into file and later rendered into figures to show graph structure *) val graph_to_trace : graph -> string - (** TODO *) + (** print graph structure *) (** {5 Create variables} *) val var_arr : ?shape:int array -> string -> arr - (** TODO *) + (** create Arr *) val var_elt : string -> elt - (** TODO *) + (** create Elt *) val const_arr : string -> A.arr -> arr - (** TODO *) + (** create Arr with constant value *) val const_elt : string -> A.elt -> elt - (** TODO *) + (** create Elt with constant value *) val assign_arr : arr -> A.arr -> unit - (** TODO *) + (** assign A.arr value to Arr *) val assign_elt : elt -> A.elt -> unit - (** TODO *) + (** assign A.elt value to Elt *) val unsafe_assign_arr : arr -> A.arr -> unit - (** TODO *) + (** assign A.arr value to Arr *) (** {5 Maths functions} *) @@ -587,32 +582,33 @@ module Make (A : Ndarray_Mutable) : sig (** {5 Evaluation functions} *) val make_graph : input:attr node array -> output:attr node array -> string -> graph - (** TODO *) + (** build a computation graph *) val get_inputs : graph -> attr node array - (** TODO *) + (** get input nodes of graph *) val get_outputs : graph -> attr node array - (** TODO *) + (** get output nodes of graph *) val make_iopair : graph -> attr node array -> attr node array -> unit - (** TODO *) + (** connect iopairs in a graph *) val update_iopair : graph -> unit - (** TODO *) + (** update iopairs in a graph *) val init_inputs : (attr node -> value) -> graph -> unit - (** TODO *) + (** initialize input nodes of a graph with given function [f] *) val optimise : graph -> unit - (** TODO *) + (** optimise graph structures *) val eval_elt : elt array -> unit - (** TODO *) + (** evaluate each Elt element in an array *) val eval_arr : arr array -> unit - (** TODO *) + (** evaluate each Arr element in an array *) val eval_graph : graph -> unit - (** TODO *) + (** evaluate all nodes in a computation graph *) + end diff --git a/src/base/types/owl_types_computation_device.ml b/src/base/types/owl_types_computation_device.ml index 3bb4ff0c2..3bcc6be9f 100644 --- a/src/base/types/owl_types_computation_device.ml +++ b/src/base/types/owl_types_computation_device.ml @@ -9,34 +9,25 @@ module type Sig = sig (** {5 Type definition} *) type device - (** TODO *) type value - (** TODO *) (** {5 Core functions} *) val make_device : unit -> device - (** TODO *) val arr_to_value : A.arr -> value - (** TODO *) val value_to_arr : value -> A.arr - (** TODO *) val elt_to_value : A.elt -> value - (** TODO *) val value_to_elt : value -> A.elt - (** TODO *) val value_to_float : value -> float - (** TODO *) val is_arr : value -> bool - (** TODO *) val is_elt : value -> bool - (** TODO *) + end