-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathppx_nn_op.ml
277 lines (268 loc) · 12.2 KB
/
ppx_nn_op.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
open Base
open Ppxlib
open Ppx_nn_shared
(** [ndarray_op ?desc_label ?axis_labels ?label expr] turns an array/list literal
    [expr] into an [FDSL.ndarray] call. The dimensions and values are obtained
    from [ndarray_constant]. [axis_labels] and [label], when present, are
    forwarded as labelled arguments. [desc_label] is converted with
    [opt_pat2string] and passed as [?desc_label]. *)
let ndarray_op ?desc_label ?axis_labels ?label expr =
  let loc = expr.pexp_loc in
  let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
  let to_list_expr = Ast_builder.Default.elist ~loc in
  let callee =
    match (label, axis_labels) with
    | None, None -> [%expr FDSL.ndarray]
    | None, Some labels -> [%expr FDSL.ndarray ~axis_labels:[%e labels]]
    | Some lbl, None -> [%expr FDSL.ndarray ~label:[%e lbl]]
    | Some lbl, Some labels -> [%expr FDSL.ndarray ~axis_labels:[%e labels] ~label:[%e lbl]]
  in
  let desc = opt_pat2string ~loc desc_label in
  let batch = to_list_expr batch_dims in
  let inputs = to_list_expr input_dims in
  let outputs = to_list_expr output_dims in
  [%expr
    [%e callee] ?desc_label:[%e desc] ~batch_dims:[%e batch] ~input_dims:[%e inputs]
      ~output_dims:[%e outputs] [%e values]]
(** [make_vb ?value ~loc ~str_loc ~ident string] builds the binding
    [let ident = FDSL.params ?value:(...) string].
    Returns the variable pattern together with the value binding. [value], if
    present, is passed wrapped in [Some]; otherwise [None] is passed. *)
let make_vb ?value ~loc ~str_loc ~ident string =
  let var_pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
  let init_value =
    match value with
    | None -> [%expr None]
    | Some c -> [%expr Some [%e c]]
  in
  (var_pat, Ast_helper.Vb.mk ~loc var_pat [%expr FDSL.params ?value:[%e init_value] [%e string]])
(** [make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc string] builds the binding
    [let ident = FDSL.params ~output_dims:[d1; ...; dn] string].
    Returns the variable pattern and the value binding. The list expression for
    the dimensions is located at [dims_loc]. *)
let make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc string =
  let var_pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
  let dims_list =
    (* Cons the dimensions onto [] starting from the last one. *)
    let loc = dims_loc in
    List.fold (List.rev dims) ~init:[%expr []] ~f:(fun tail d -> [%expr [%e d] :: [%e tail]])
  in
  let params_call = [%expr FDSL.params ~output_dims:[%e dims_list] [%e string]] in
  (var_pat, Ast_helper.Vb.mk ~loc var_pat params_call)
(** [make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string] builds a
    parameter binding whose initial values come from the array/list literal
    [init_nd]. The result has the form
    [let ident = FDSL.params ~input_dims ~output_dims ~values string].
    If [init_nd] has batch dimensions, the bound expression is instead an
    error extension node carrying a compile-time message. *)
let make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
  let var_pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
  let values, batch_dims, output_dims, input_dims = ndarray_constant init_nd in
  let params_call =
    match batch_dims with
    | _ :: _ ->
        (* Batch axes are not accepted here; report the error at [loc]. *)
        Location.error_extensionf ~loc
          "ppx_ocannl params cannot have batch dims: define a constant or remove the array syntax."
        |> Ast_builder.Default.pexp_extension ~loc
    | [] ->
        let to_list_expr = Ast_builder.Default.elist ~loc in
        let callee =
          match axis_labels with
          | Some labels -> [%expr FDSL.params ~axis_labels:[%e labels]]
          | None -> [%expr FDSL.params]
        in
        let inputs = to_list_expr input_dims in
        let outputs = to_list_expr output_dims in
        [%expr
          [%e callee] ~input_dims:[%e inputs] ~output_dims:[%e outputs] ~values:[%e values]
            [%e string]]
  in
  (var_pat, Ast_helper.Vb.mk ~loc var_pat params_call)
(** [translate ?desc_label expr] rewrites an expression written in the %nn_op DSL
    into calls to [FDSL].

    Returns a pair [(vbs, expr')]. [vbs] maps parameter names to value bindings
    built by [make_vb], [make_vb_dims] or [make_vb_nd], i.e. [FDSL.params ...]
    definitions introduced wherever a string literal is used as a parameter.
    The caller must bind [vbs] around [expr'] (e.g. with [let_opt]).

    [desc_label], when given, is converted by [opt_pat2string] and passed as
    [?desc_label] to the FDSL call generated for the outermost operation. *)
let rec translate ?desc_label expr =
  let loc = expr.pexp_loc in
  match expr with
  | { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
      (no_vbs, [%expr FDSL.number ?desc_label:[%e opt_pat2string ~loc desc_label] [%e expr]])
  | { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
      (* Forward [desc_label] just like the float and axis-labelled integer cases. *)
      ( no_vbs,
        [%expr FDSL.number ?desc_label:[%e opt_pat2string ~loc desc_label] (Float.of_int [%e expr])] )
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
        [%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
      (* A character applied to a number names the number's single axis. *)
      let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in
      ( no_vbs,
        [%expr FDSL.number ?desc_label:[%e opt_pat2string ~loc desc_label] ~axis_label:[%e axis] [%e f]] )
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
        [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
      let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in
      ( no_vbs,
        [%expr
          FDSL.number ?desc_label:[%e opt_pat2string ~loc desc_label] ~axis_label:[%e axis]
            (Float.of_int [%e i])] )
  | [%expr
      [%e? expr1]
      *+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [%e? expr2]]
    when String.contains spec_str '>' ->
      (* Binary einsum: [a *+ "spec" b]. *)
      let vbs1, e1 = translate expr1 in
      let vbs2, e2 = translate expr2 in
      ( reduce_vbss [ vbs1; vbs2 ],
        [%expr FDSL.einsum ?desc_label:[%e opt_pat2string ~loc desc_label] [%e spec] [%e e1] [%e e2]] )
  | [%expr [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
    when String.contains spec_str '>' ->
      (* Unary einsum: [a ++ "spec"]. *)
      let vbs1, e1 = translate expr1 in
      (vbs1, [%expr FDSL.einsum1 ?desc_label:[%e opt_pat2string ~loc desc_label] [%e spec] [%e e1]])
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
        [%e?
          ( { pexp_desc = Pexp_constant (Pconst_integer _); pexp_loc = dims_loc; _ }
          | { pexp_desc = Pexp_ident _; pexp_loc = dims_loc; _ } ) as d]] ->
      (* ["name" dim]: a parameter with a single output dimension. *)
      let pat, vb = make_vb_dims ~loc ~str_loc ~ident ~dims:(convert_dsl_dims [ d ]) ~dims_loc s in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
        [%e?
          ({ pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ })
          as init_nd]] ->
      (* ["name" literal]: a parameter initialized from an array/list literal. *)
      let pat, vb = make_vb_nd ~loc ~str_loc ~ident ~init_nd s in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | [%expr
      [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
        [%e? { pexp_desc = Pexp_tuple dims; pexp_loc = dims_loc; _ }]] ->
      (* ["name" (d1, ..., dn)]: a parameter with several output dimensions. *)
      let dims = convert_dsl_dims dims in
      let pat, vb = make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc s in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } ->
      (* A bare ["name"]: a parameter with no explicit dimensions. *)
      let pat, vb = make_vb ~loc ~str_loc ~ident expr in
      (Map.singleton (module String) ident vb, pat2expr pat)
  | { pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
      (no_vbs, ndarray_op ?desc_label expr)
  | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] ->
      (* We need to hardcode these two patterns to prevent the numbers from being converted
         to formulas. *)
      let vbs, e1 = translate expr1 in
      (vbs, [%expr FDSL.O.( **. ) ?desc_label:[%e opt_pat2string ~loc desc_label] [%e e1] [%e f]])
  | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
      let vbs, e1 = translate expr1 in
      ( vbs,
        [%expr FDSL.O.( **. ) ?desc_label:[%e opt_pat2string ~loc desc_label] [%e e1] (Float.of_int [%e i])]
      )
  | [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
      (* Only the function position receives [desc_label]. *)
      let vbs1, e1 = translate ?desc_label expr1 in
      let vbs2, e2 = translate expr2 in
      let vbs3, e3 = translate expr3 in
      (reduce_vbss [ vbs1; vbs2; vbs3 ], [%expr [%e e1] [%e e2] [%e e3]])
  | [%expr [%e? expr1] [%e? expr2]] ->
      let vbs1, e1 = translate ?desc_label expr1 in
      let vbs2, e2 = translate expr2 in
      (Map.merge_skewed vbs1 vbs2 ~combine:(fun ~key:_ _v1 v2 -> v2), [%expr [%e e1] [%e e2]])
  | [%expr fun ~config [%p? pat1] [%p? pat2] -> [%e? body]] ->
      (* TODO(#38): generalize config to any number of labeled arguments with any labels. *)
      (* Parameters are bound under [fun ~config ->], not hoisted past it. *)
      let vbs, body = translate ?desc_label body in
      (no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs [%expr fun [%p pat1] [%p pat2] -> [%e body]]]])
  | [%expr fun ~config [%p? pat] -> [%e? body]] ->
      (* TODO(#38): generalize config to any number of labeled arguments with any labels. *)
      let vbs, body = translate ?desc_label body in
      (no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs [%expr fun [%p pat] -> [%e body]]]])
  | [%expr fun [%p? pat1] [%p? pat2] -> [%e? body]] ->
      (* Parameters of the body are hoisted out of the function. *)
      let vbs, body = translate ?desc_label body in
      (vbs, [%expr fun [%p pat1] [%p pat2] -> [%e body]])
  | [%expr fun [%p? pat] -> [%e? body]] ->
      let vbs, body = translate ?desc_label body in
      (vbs, [%expr fun [%p pat] -> [%e body]])
  | [%expr
      while [%e? test_expr] do
        [%e? body_expr]
      done] ->
      let vbs, body = translate ?desc_label body_expr in
      ( vbs,
        [%expr
          while [%e test_expr] do
            [%e body]
          done] )
  | [%expr
      for [%p? pat] = [%e? init] to [%e? final] do
        [%e? body_expr]
      done] ->
      let vbs, body = translate ?desc_label body_expr in
      ( vbs,
        [%expr
          for [%p pat] = [%e init] to [%e final] do
            [%e body]
          done] )
  | [%expr
      for [%p? pat] = [%e? init] downto [%e? final] do
        [%e? body_expr]
      done] ->
      let vbs, body = translate ?desc_label body_expr in
      ( vbs,
        [%expr
          for [%p pat] = [%e init] downto [%e final] do
            [%e body]
          done] )
  | [%expr
      [%e? expr1];
      [%e? expr2]] ->
      (* The value of a sequence is its last expression, so it receives [desc_label]. *)
      let vbs1, e1 = translate expr1 in
      let vbs2, e2 = translate ?desc_label expr2 in
      ( reduce_vbss [ vbs1; vbs2 ],
        [%expr
          [%e e1];
          [%e e2]] )
  | [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
      (* The condition is left untranslated; only the branches are rewritten. *)
      let vbs2, e2 = translate ?desc_label expr2 in
      let vbs3, e3 = translate ?desc_label expr3 in
      (reduce_vbss [ vbs2; vbs3 ], [%expr if [%e expr1] then [%e e2] else [%e e3]])
  | [%expr if [%e? expr1] then [%e? expr2]] ->
      let vbs2, e2 = translate ?desc_label expr2 in
      (vbs2, [%expr if [%e expr1] then [%e e2]])
  | { pexp_desc = Pexp_match (expr1, cases); _ } ->
      (* The scrutinee is left untranslated; only case right-hand sides are rewritten. *)
      let vbss, cases =
        List.unzip
        @@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
               let vbs, pc_rhs = translate ?desc_label pc_rhs in
               (vbs, { c with pc_rhs }))
      in
      (reduce_vbss vbss, { expr with pexp_desc = Pexp_match (expr1, cases) })
  | { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
      let vbss1, bindings =
        List.unzip
        @@ List.map bindings ~f:(fun binding ->
               let vbs, pvb_expr = translate ~desc_label:binding.pvb_pat binding.pvb_expr in
               (vbs, { binding with pvb_expr }))
      in
      let vbs2, body = translate ?desc_label body in
      (* Bind the collected parameters in their own enclosing lets instead of
         appending them to this [let ... and ...] group. Two reasons:
         - In a non-recursive [and] group the right-hand sides cannot see each
           other, so a binding that uses a parameter would be unbound.
         - [let rec] rejects the [FDSL.params ...] applications as right-hand
           sides.
         Parameters used by the right-hand sides go outside this let; parameters
         used by the body go just inside it. *)
      let body = let_opt ~loc vbs2 body in
      ( no_vbs,
        let_opt ~loc (reduce_vbss vbss1) { expr with pexp_desc = Pexp_let (recflag, bindings, body) } )
  | { pexp_desc = Pexp_open (decl, body); _ } ->
      let vbs, body = translate ?desc_label body in
      (vbs, { expr with pexp_desc = Pexp_open (decl, body) })
  | { pexp_desc = Pexp_letmodule (name, module_expr, body); _ } ->
      let vbs, body = translate ?desc_label body in
      (vbs, { expr with pexp_desc = Pexp_letmodule (name, module_expr, body) })
  | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
      (* A bare operator is partially applied to its [?desc_label]. *)
      (no_vbs, [%expr [%e expr] ?desc_label:[%e opt_pat2string ~loc desc_label]])
  | expr -> (no_vbs, expr)
(** [expr_expander ~loc ~path payload] expands an expression-level extension.
    - For a [let]-payload, each bound expression is translated and wrapped in
      [let open! FDSL.O in ...]. The parameters collected from all bindings are
      bound around the whole [let].
    - Any other payload is translated as a whole, with its parameters bound
      around it. *)
let expr_expander ~loc ~path:_ payload =
  match payload with
  | { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
      (* At the annotation level only the bound expressions are translated; the
         body of the let is kept as written. *)
      let translate_binding vb =
        let vbs, v = translate ~desc_label:vb.pvb_pat vb.pvb_expr in
        let pvb_expr = [%expr let open! FDSL.O in [%e v]] in
        (vbs, { vb with pvb_expr })
      in
      let vbss, bindings = List.map bindings ~f:translate_binding |> List.unzip in
      let_opt ~loc (reduce_vbss vbss) { payload with pexp_desc = Pexp_let (recflag, bindings, body) }
  | _ ->
      let vbs, translated = translate payload in
      let_opt ~loc vbs translated
(** [flatten_str ~loc ~path items] returns a single structure item. A singleton
    list yields its only element. Any other list is wrapped as
    [include struct items end], located at [loc]. *)
let flatten_str ~loc ~path:_ items =
  match items with
  | [ single ] -> single
  | [] | _ :: _ :: _ ->
      let pincl_mod = Ast_helper.Mod.structure items in
      Ast_helper.Str.include_ { pincl_mod; pincl_loc = loc; pincl_attributes = [] }
(** [translate_str str] rewrites one structure item.
    - [let] items: every bound expression is translated. Its parameters are
      bound inside the binding, and the result is wrapped in
      [let open! FDSL.O in ...].
    - Toplevel evaluations: the expression is translated and its parameters are
      bound around it.
    - Any other item is returned unchanged. *)
let translate_str ({ pstr_desc; pstr_loc = loc; _ } as str) =
  let with_fdsl_ops vb =
    let loc = vb.pvb_loc in
    let vbs, v = translate ~desc_label:vb.pvb_pat vb.pvb_expr in
    let body = let_opt ~loc vbs v in
    { vb with pvb_expr = [%expr let open! FDSL.O in [%e body]] }
  in
  match pstr_desc with
  | Pstr_value (recf, bindings) ->
      { str with pstr_desc = Pstr_value (recf, List.map bindings ~f:with_fdsl_ops) }
  | Pstr_eval (expr, attrs) ->
      let vbs, translated = translate expr in
      { str with pstr_desc = Pstr_eval (let_opt ~loc vbs translated, attrs) }
  | _ -> str
(** [str_expander ~loc ~path payload] translates every structure item of the
    payload and merges the results into a single structure item. *)
let str_expander ~loc ~path (payload : structure_item list) =
  let translated = List.map payload ~f:translate_str in
  flatten_str ~loc ~path translated