-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathoperation.ml
328 lines (272 loc) · 13.4 KB
/
operation.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
(** Computational primitives for neural networks, integrating [Formula] with [Code]. *)
open Base
module CDSL = Code.CDSL
(** Pointwise addition of two formulas, built with [Formula.binop] using the [Pointwise_bin] compose
    operation and the label ["+"].
    - Forward code: [n =: n1 + n2].
    - Backprop code: [n.grad] is added into both [n1.grad] and [n2.grad].
    The value is [Formula.binop] with these arguments already supplied; the uses later in this file
    additionally pass [~is_form] and, optionally, [?desc_label], followed by the two operand formulas. *)
let add =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion, so that no
   formula-level operators are in scope for the code expressions below -- confirm against the ppx. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections = n =: n1 + n2 in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections =
(* NOTE(review): [||] presumably composes the two updates as independent code -- confirm. *)
n1.grad =+ n.grad || n2.grad =+ n.grad
in
Formula.binop ~compose_op:Pointwise_bin ~op_label:"+" ~op_body ~grad_body
(** Pointwise multiplication of two formulas, built with [Formula.binop] using the [Pointwise_bin]
    compose operation and the label ["*."].
    - Forward code: [n =: n1 * n2].
    - Backprop code: [n1.grad] receives [n.grad * n2] and [n2.grad] receives [n1 * n.grad], both via [=+].
    The uses later in this file additionally pass [~is_form] and the two operand formulas. *)
let pointmul =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections = n =: n1 * n2 in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections =
(* Product rule: each operand gradient is the incoming gradient times the other operand. *)
n1.grad =+ n.grad * n2 || n2.grad =+ n1 * n.grad
in
Formula.binop ~compose_op:Pointwise_bin ~op_label:"*." ~op_body ~grad_body
(* Matrix multiplication, built with [Formula.binop] using the [Compose] operation and the label "*".
   Notation: N1: AxB, N2: BxC, N: AxC; A: output of N1, B: input of N1 / output of N2, C: input of N2.
   Although the matrix algebra would require that we insert additional transposes in gradient multiplies:
   AxB = AxC * CxB = AxC * (BxC)^T -> N1g += Ng * N2v^T,
   BxC = BxA * AxC = (AxB)^T * AxC -> N2g += N1v^T * Ng,
   in our setup there is no transposing to do, since the projections produce correct indices for their
   corresponding matrices. *)
let matmul =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
(* NOTE(review): [=:+] presumably means "initialize, then accumulate with +" over the projected-away
   axes, as opposed to the plain assignment [=:] used by the pointwise operations -- confirm. *)
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections = n =:+ n1 * n2 in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections =
(* Same expressions as for [pointmul]; per the comment above, the projections supply the indexing. *)
n1.grad =+ n.grad * n2 || n2.grad =+ n1 * n.grad
in
Formula.binop ~compose_op:Compose ~op_label:"*" ~op_body ~grad_body
(** Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of
matrix multiplication, inner and outer products, etc.
Note that ["a,b->c"] from [numpy] is ["a;b=>c"] in OCANNL, since ["->"] is used to separate the input
and the output axes.
[spec] is passed on unchanged as the compose operation [Einsum spec]; the operation label is [";=>"].
The forward and backprop code are the same expressions as for [matmul]. The remaining arguments are
those of [Formula.binop]; the uses later in this file pass [~is_form]. *)
let einsum ?desc_label spec =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections = n =:+ n1 * n2 in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections =
n1.grad =+ n.grad * n2 || n2.grad =+ n1 * n.grad
in
Formula.binop ?desc_label ~compose_op:(Einsum spec) ~op_label:";=>" ~op_body ~grad_body
(** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract diagonals,
compute traces etc.
Note that ["a->c"] from [numpy] is ["a=>c"] in OCANNL, since ["->"] is used to separate the input
and the output axes.
[spec] is passed on unchanged as the transpose operation [Permute spec] of [Formula.unop]; the operation
label is ["=>"]. Forward code: [n =:+ n1]; backprop code: [n1.grad =+ n.grad]. The uses later in this
file pass [~is_form]. *)
let einsum1 ?desc_label spec =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~projections = n =:+ n1 in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~projections = n1.grad =+ n.grad in
Formula.unop ?desc_label ~transpose_op:(Permute spec) ~op_label:"=>" ~op_body ~grad_body
(** The unary operation labelled ["r"], built with [Formula.unop] using the [Pointwise_un] transpose
    operation.
    - Forward code: [n =: !/n1].
    - Backprop code: [n1.grad =+ n -?/ n.grad].
    NOTE(review): [!/] is presumably the rectified-linear unit and [-?/] its gate (pass [n.grad] where
    [n] is positive) -- confirm against the [%nn_cd] operator table.
    The uses later in this file pass [~is_form] and the operand formula. *)
let relu =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~projections = n =: !/n1 ~projections in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~projections = n1.grad =+ n -?/ n.grad in
Formula.unop ~transpose_op:Pointwise_un ~op_label:"r" ~op_body ~grad_body
(** Builds the operator-like label of a dynamic-subtensor operation: ["@"] followed by two
    one-character markers.
    - Axis kind marker: ["|"] for [Batch], ["/"] for [Input], ["-"] for [Output].
    - Pointwise marker: ["."] when [other_axes_pointwise] is true, ["^"] otherwise.
    When [from_left] is true the pointwise marker comes first, otherwise the axis kind marker does. *)
let subtensor_label ~over_kind ~from_left ~other_axes_pointwise =
  let pointwise_spec = if other_axes_pointwise then "." else "^" in
  let kind_spec =
    match over_kind with
    | Shape.AxisKey.Batch -> "|"
    | Input -> "/"
    | Output -> "-"
  in
  (* Only the order of the two markers depends on [from_left]. *)
  let markers = if from_left then [ pointwise_spec; kind_spec ] else [ kind_spec; pointwise_spec ] in
  String.concat ("@" :: markers)
(** Dynamic indexing: a binary formula built with [Formula.binop] whose compose operation is
    [Shape.Dynamic_index { over_kind; from_left; other_axes_pointwise; indexed_dims }] and whose label
    comes from [subtensor_label].
    - Forward code: [n =: n1 -@> n2].
    - Backprop code: only [n1.grad] is updated; no gradient code is emitted for [n2].
    The four parameters are forwarded unchanged into the [Dynamic_index] record. The uses later in this
    file pass [~is_form] and the two operand formulas. *)
let dynamic_subtensor ?indexed_dims ~over_kind ~from_left ~other_axes_pointwise =
let open Code in
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections = n =: n1 -@> n2 in
let%nn_cd grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections =
(* [projections] tracks the dynamic indexing for [n] (and not [n1]) as a slice.
[-@>] simply means [Arg1]: take the first argument, ignore the second argument. *)
n1.grad =+ n.grad -@> n2
in
let compose_op = Shape.Dynamic_index { over_kind; from_left; other_axes_pointwise; indexed_dims } in
let op_label = subtensor_label ~over_kind ~from_left ~other_axes_pointwise in
Formula.binop ~compose_op ~op_label ~op_body ~grad_body
(** Operators for non-form computations: every operation here is applied with [~is_form:false].
    The power operator [( **. )] is deliberately absent; this module is used as [NFDSL.O] inside
    [pointpow] and is included by [NFO], which adds the power and division operators. *)
module NFO_without_pow = struct
(* Arithmetic: [( * )] is [matmul], [( *. )] is [pointmul], [( + )] is [add], [( !/ )] is [relu]. *)
let ( * ) = matmul ~is_form:false
let ( *. ) = pointmul ~is_form:false
let ( + ) = add ~is_form:false
let ( !/ ) = relu ~is_form:false
(* Constants: [!.] wraps a float via [Formula.number], [!..] converts an int to a float first. *)
let ( !. ) = Formula.number ~is_form:false
let ( !.. ) ?desc_label i = Formula.number ?desc_label ~is_form:false @@ Float.of_int i
(* Subtraction and negation are expressed through pointwise multiplication by the constant -1. *)
let ( - ) ?desc_label m1 m2 = ( + ) ?desc_label m1 (!.(-1.) *. m2)
let ( ~- ) ?desc_label m = ( *. ) ?desc_label !.(-1.) m
(* Dynamic-subtensor operators. After the leading [@]: [.] means [~other_axes_pointwise:true] and [^]
   means [false]; [|], [/], [-] select the [Batch], [Input], [Output] axis kind respectively. The
   pointwise marker comes first for [~from_left:true] and last for [~from_left:false]. *)
let ( @.| ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:true ~other_axes_pointwise:true ~is_form:false
let ( @./ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:true ~other_axes_pointwise:true ~is_form:false
let ( @.- ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:true ~other_axes_pointwise:true
~is_form:false
let ( @^| ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:true ~other_axes_pointwise:false
~is_form:false
let ( @^/ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:true ~other_axes_pointwise:false
~is_form:false
let ( @^- ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:true ~other_axes_pointwise:false
~is_form:false
let ( @|. ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:false ~other_axes_pointwise:true
~is_form:false
let ( @/. ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:false ~other_axes_pointwise:true
~is_form:false
let ( @-. ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:false ~other_axes_pointwise:true
~is_form:false
let ( @|^ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:false ~other_axes_pointwise:false
~is_form:false
let ( @/^ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:false ~other_axes_pointwise:false
~is_form:false
let ( @-^ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:false ~other_axes_pointwise:false
~is_form:false
end
(** [pointpow ?desc_label ~is_form p m1] raises the formula [m1] pointwise to the float power [p].
    The exponent is wrapped as the constant formula [p_f] via [Formula.number] and passed as the second
    operand of a [Pointwise_bin] binop labelled ["**."].
    Backprop code:
    - when [is_form] is false: [Noop];
    - when [p] equals [2.0]: [n1.grad =+ p_f *. m1 * n.grad];
    - otherwise: [n1.grad =+ p_f *. (m1 **. (p -. 1.)) * n.grad].
    The exponent node [n2] is ignored by every backprop variant.
    NOTE(review): [rec] is presumably needed because the [%nn_cd] expansion of [**.] refers back to
    [pointpow] -- confirm against the ppx. *)
let rec pointpow ?desc_label ~is_form p m1 : Formula.t =
(* Unlike the other operations, [NFDSL.O] is not empty here: the backprop expressions below use the
   non-form operators of [NFO_without_pow]. *)
let module NFDSL = struct
module O = NFO_without_pow
end in
let open Code in
let p_f = Formula.number ~is_form p in
let%nn_cd op_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~(n2 : NodeUI.t) ~projections =
n =: n1 ** n2 ~projections
in
let%nn_cd grad_body =
if not is_form then fun ~n:_ ~n1:_ ~n2:_ ~projections:_ -> Noop
else if Float.equal p 2.0 then fun ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~n2:_ ~projections ->
(* Special case for squaring: avoids building [m1 **. 1.]. *)
n1.grad =+ p_f *. m1 * n.grad
else fun ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~n2:_ ~projections ->
(* Power rule: the derivative of [m1 ** p] with respect to [m1] is [p * m1 ** (p - 1)]. *)
n1.grad =+ p_f *. (m1 **. (p -. 1.)) * n.grad
in
Formula.binop ?desc_label ~compose_op:Pointwise_bin ~op_label:"**." ~op_body ~grad_body ~is_form m1 p_f
(** [range ?desc_label ~is_form ?axis_label upto] is a gradient-free term labelled ["0...<upto>"] with no
    batch and no input axes and a single output axis of dimension [upto + 1], initialized with
    [Range_over_offsets]. [axis_label] is forwarded to {!Formula.term} as [?axis_labels]. *)
let range ?desc_label ~is_form ?axis_label upto =
  let label = String.concat [ "0"; "..."; Int.to_string upto ] in
  (* The range includes [upto] itself, hence one more element than [upto]. *)
  let size = upto + 1 in
  Formula.term ?desc_label ~is_form ~needs_gradient:false ~label ~batch_dims:[] ~input_dims:[]
    ~output_dims:[ Dim size ] ?axis_labels:axis_label ~init_op:Range_over_offsets ()
(** [range_of_shape ?desc_label ~is_form ?batch_dims ?input_dims ?output_dims ?axis_labels ()] is a
    gradient-free term of the given dimensions (each defaulting to the empty list), initialized with
    [Range_over_offsets]. Its label is ["r"] followed by [NodeUI.dims_to_string] applied to the
    dimensions listed in the order batch, output, input. *)
let range_of_shape ?desc_label ~is_form ?(batch_dims = []) ?(input_dims = []) ?(output_dims = []) ?axis_labels
    () =
  (* Note the order used for the label: output axes come before input axes. *)
  let dims = Array.of_list (List.concat [ batch_dims; output_dims; input_dims ]) in
  let label = "r" ^ NodeUI.dims_to_string dims in
  Formula.term ?desc_label ~is_form ~needs_gradient:false ~label ~batch_dims ~input_dims ~output_dims
    ?axis_labels ~init_op:Range_over_offsets ()
(** A form term ([~is_form:true]) whose contents are provided by [fetch_op].
    In {!Formula.term} the omitted axes are {!Shape.Unknown} -- to be inferred; here every omitted axis
    kind is known and empty. [needs_gradient] defaults to [false]. *)
let data ?desc_label ?axis_labels ?(needs_gradient = false) ~label ?batch_dims ?input_dims ?output_dims
    fetch_op =
  (* An omitted axis kind means "no axes of that kind", not "to be inferred". *)
  let or_empty dims = Option.value dims ~default:[] in
  let batch_dims = or_empty batch_dims in
  let input_dims = or_empty input_dims in
  let output_dims = or_empty output_dims in
  Formula.term ?desc_label ?axis_labels ~is_form:true ~needs_gradient ~label ~batch_dims ~input_dims
    ~output_dims ~fetch_op ()
(** Non-form computations that happen at the end (potentially in parallel).
    Builds a gradient-free, non-form term ([~is_form:false], [~needs_gradient:false]) driven by
    [postprocess_op]; every omitted axis kind is the empty list. *)
let result ?desc_label ?axis_labels ~label ?batch_dims ?input_dims ?output_dims postprocess_op =
  (* An omitted axis kind means "no axes of that kind". *)
  let or_empty dims = Option.value dims ~default:[] in
  let batch_dims = or_empty batch_dims in
  let input_dims = or_empty input_dims in
  let output_dims = or_empty output_dims in
  Formula.term ?desc_label ?axis_labels ~is_form:false ~needs_gradient:false ~label ~batch_dims
    ~input_dims ~output_dims ~postprocess_op ()
(** [assign ~lhs ~rhs ~projections] is the code [lhs =: rhs] over two tensor pointers, with the given
    [projections]. The local name [assign] is bound by [let%nn_cd] and then returned as the value. *)
let assign =
(* NOTE(review): the empty [NFDSL.O] looks like a stub needed by the [%nn_cd] expansion -- confirm. *)
let module NFDSL = struct
module O = struct end
end in
let%nn_cd assign ~(lhs : NodeUI.tensor_ptr) ~(rhs : NodeUI.tensor_ptr) ~projections =
lhs =: rhs ~projections
in
assign
(** [assign_op field ~n ~n1 ~projections] assigns the [field] of node [n1] to the same [field] of node
    [n], where [field] maps a node to the tensor pointer to use (in this file:
    [Code.CDSL.data_of_node Value] or [Code.CDSL.data_of_node Grad]). *)
let assign_op field ~(n : NodeUI.t) ~(n1 : NodeUI.t) ~projections =
  let rhs = field n1 in
  let lhs = field n in
  assign ~lhs ~rhs ~projections
(** A [stop_gradient] is an identity in the forward pass and a no-op in the backprop pass.
    It is a form operation ([~is_form:true]) labelled ["stop_grad"]: the forward code copies the value
    of the operand node, and the backprop code is [Code.Noop]. *)
let stop_gradient =
  (* Forward: copy the [Value] data of the operand into the result. *)
  let forward = assign_op (Code.CDSL.data_of_node Value) in
  (* Backward: emit nothing, so no gradient reaches the operand. *)
  let no_backprop ~n:_ ~n1:_ ~projections:_ = Code.Noop in
  Formula.unop ~transpose_op:Pointwise_un ~op_label:"stop_grad" ~op_body:forward ~grad_body:no_backprop
    ~is_form:true
(** A [stop_broadcast] mutates the partially-inferred shape of a formula in-place, substituting-in
    a [Fixed] marker on the dimensions, so that no new node needs to be introduced.
    The result is whatever [Shape.set_dims_type] returns. *)
let stop_broadcast { Formula.shape; _ } = Shape.set_dims_type shape Shape.fixed
(** [identity] introduces a new node, which is an identity in both the forward and backward pass.
    The new node starts from the shape of [m] ([~init_shape]) and is labelled ["="]. The forward code
    copies the value of the operand; the backprop code copies the gradient of the new node into the
    gradient of the operand.
    NOTE(review): the backprop goes through [assign_op], i.e. an assignment rather than the [=+]
    accumulation used by the other operations in this file -- confirm this is intended when [m] has
    other consumers. *)
let identity ?desc_label ~is_form m =
  let op_body = assign_op (Code.CDSL.data_of_node Value) in
  let grad_body ~(n : NodeUI.t) ~(n1 : NodeUI.t) =
    (* The roles are swapped on purpose: in backprop the operand [n1] is the destination. *)
    let grad_of = Code.CDSL.data_of_node Grad in
    assign_op grad_of ~n:n1 ~n1:n
  in
  Formula.unop ?desc_label ~init_shape:m.Formula.shape ~transpose_op:Pointwise_un ~op_label:"=" ~op_body
    ~grad_body ~is_form
(** Operators for form computations: every operation here is applied with [~is_form:true].
    Mirrors [NFO_without_pow], and additionally provides the power operator [( **. )], the parameter
    constructor [( !~ )] and the pointwise division [( /. )]. *)
module O = struct
(* Arithmetic: [( * )] is [matmul], [( *. )] is [pointmul], [( + )] is [add]. *)
let ( * ) = matmul ~is_form:true
let ( *. ) = pointmul ~is_form:true
let ( + ) = add ~is_form:true
(* [base **. exp]: note that [pointpow] takes the float exponent before the base formula. *)
let ( **. ) ?desc_label base exp = pointpow ?desc_label exp base ~is_form:true
let ( !/ ) = relu ~is_form:true
(* [!~ label] is [Formula.params label]. *)
let ( !~ ) ?desc_label label = Formula.params ?desc_label label
(* Constants: [!.] wraps a float via [Formula.number], [!..] converts an int to a float first. *)
let ( !. ) = Formula.number ~is_form:true
let ( !.. ) ?desc_label i = Formula.number ?desc_label ~is_form:true @@ Float.of_int i
(* Subtraction and negation are expressed through pointwise multiplication by the constant -1;
   division is pointwise multiplication by the divisor raised to the power -1. *)
let ( - ) ?desc_label m1 m2 = ( + ) ?desc_label m1 (!.(-1.) *. m2)
let ( ~- ) ?desc_label m = ( *. ) ?desc_label !.(-1.) m
let ( /. ) ?desc_label m1 m2 = ( *. ) ?desc_label m1 (m2 **. -1.0)
(* Dynamic-subtensor operators. After the leading [@]: [.] means [~other_axes_pointwise:true] and [^]
   means [false]; [|], [/], [-] select the [Batch], [Input], [Output] axis kind respectively. The
   pointwise marker comes first for [~from_left:true] and last for [~from_left:false]. *)
let ( @.| ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:true ~other_axes_pointwise:true ~is_form:true
let ( @./ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:true ~other_axes_pointwise:true ~is_form:true
let ( @.- ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:true ~other_axes_pointwise:true ~is_form:true
let ( @^| ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:true ~other_axes_pointwise:false ~is_form:true
let ( @^/ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:true ~other_axes_pointwise:false ~is_form:true
let ( @^- ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:true ~other_axes_pointwise:false
~is_form:true
let ( @|. ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:false ~other_axes_pointwise:true ~is_form:true
let ( @/. ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:false ~other_axes_pointwise:true ~is_form:true
let ( @-. ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:false ~other_axes_pointwise:true
~is_form:true
let ( @|^ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Batch ~from_left:false ~other_axes_pointwise:false
~is_form:true
let ( @/^ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Input ~from_left:false ~other_axes_pointwise:false
~is_form:true
let ( @-^ ) =
dynamic_subtensor ~over_kind:Shape.AxisKey.Output ~from_left:false ~other_axes_pointwise:false
~is_form:true
end
(** The DSL for form computations: extends [Formula.FDSL] with the form operators [O] and with the
    operations of this file specialized to [~is_form:true]. *)
module FDSL = struct
include Formula.FDSL
module O = O
let einsum ?desc_label s = einsum ?desc_label s ~is_form:true
let einsum1 ?desc_label s = einsum1 ?desc_label s ~is_form:true
let range = range ~is_form:true
let range_of_shape = range_of_shape ~is_form:true
(* Rebind the following under the same names inside this module. *)
let data = data
let stop_broadcast = stop_broadcast
let stop_gradient = stop_gradient
(* [init_const ~l ?b ?i ?o cs]: a term labelled [l] without gradient, with batch / input / output
   dimensions [b] / [i] / [o] (each defaulting to the empty list), initialized with
   [Constant_fill cs].
   NOTE(review): [term] is not defined in this file; presumably it comes from the included
   [Formula.FDSL] -- confirm. *)
let init_const ~l ?(b = []) ?(i = []) ?(o = []) cs =
term ~label:l ~needs_gradient:false ~batch_dims:b ~input_dims:i ~output_dims:o ~init_op:(Constant_fill cs)
()
(* [init_param] is the same as [init_const] except for [~needs_gradient:true]. *)
let init_param ~l ?(b = []) ?(i = []) ?(o = []) cs =
term ~label:l ~needs_gradient:true ~batch_dims:b ~input_dims:i ~output_dims:o ~init_op:(Constant_fill cs)
()
end
(** The full set of operators for non-form computations: [NFO_without_pow] extended with the pointwise
    power [( **. )] and the pointwise division [( /. )]. *)
module NFO = struct
  include NFO_without_pow

  (** [base **. exp] raises [base] pointwise to the float power [exp]; note that [pointpow] takes the
      exponent before the base. *)
  let ( **. ) ?desc_label base exp = pointpow ?desc_label ~is_form:false exp base

  (** [m1 /. m2] is [m1] multiplied pointwise by [m2] raised to the power [-1.0]. *)
  let ( /. ) ?desc_label m1 m2 =
    let reciprocal = m2 **. -1.0 in
    ( *. ) ?desc_label m1 reciprocal
end
(** The DSL for non-form computations: extends [Formula.NFDSL] with the non-form operators [NFO] and
    with the operations of this file specialized to [~is_form:false]. *)
module NFDSL = struct
include Formula.NFDSL
module O = NFO
let einsum ?desc_label s = einsum ?desc_label s ~is_form:false
let einsum1 ?desc_label s = einsum1 ?desc_label s ~is_form:false
(* [term] is [Formula.term] restricted to non-form terms. *)
let term = Formula.term ~is_form:false
(* Rebind [result] under the same name inside this module. *)
let result = result
let range = range ~is_form:false
let range_of_shape = range_of_shape ~is_form:false
end