-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathshape.ml
1687 lines (1583 loc) · 84.3 KB
/
shape.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(** Tensor shape types and inference. *)
open Base
(** Identifies one axis of a shape: the kind of the axis ([Batch], [Input] or [Output]) together with
    its position, which is counted from the end to facilitate broadcasting.

    Note an inconsistency caused by differing conventions in function notation and matrix notation:
    label specifications and the einsum notation are written "batch|inputs->outputs", whereas a shape
    is converted to a [Code] index in the order [[batch; outputs; inputs]]. *)
module AxisKey = struct
  module T = struct
    type kind = Batch | Input | Output [@@deriving equal, compare, sexp, variants]

    type t = {
      in_axes : kind;
      from_end : int;
          (** Axes are indexed from the end, to avoid reindexing when broadcasting; starting with [1]. *)
    }
    [@@deriving equal, compare, sexp]

    (** A compact rendering of a key: a three-letter tag of the kind followed by the position. *)
    let to_string { in_axes; from_end } =
      let kind_tag = match in_axes with Batch -> "bch" | Input -> "inp" | Output -> "out" in
      kind_tag ^ Int.to_string from_end
  end

  include T
  include Comparator.Make (T)
end
(** A map whose keys are axes ([AxisKey.t]); the values are of the parameter type. *)
type 'a axis_map = 'a Map.M(AxisKey).t [@@deriving compare, sexp]
(** A global mutable setting, initially [1]. The documentation of [Parallel] refers to it as the size of
parallel axes. *)
let num_parallel_tasks = ref 1
(** The description of a single axis of a shape. *)
type dim =
| Dim of int (** An axis of the given size. *)
| Parallel
(** An axis of size [!num_parallel_tasks], such that if an index is derived for it, the index
will be [Task_id]. *)
| Frozen of int
(** An axis that should be indexed at a single position during a single `refresh_session`:
a dynamic index into it will be a [Frozen_recipient]. *)
[@@deriving equal, compare, sexp, variants]
type parsed_axis_labels = {
bcast_batch : bool;
bcast_input : bool;
bcast_output : bool;
given_batch : int;
given_input : int;
given_output : int;
labels : (string, dim) Either.t axis_map;
}
[@@deriving compare, sexp, fields]
(** The result of parsing a labels specification. The [labels] map assigns to an [AxisKey] axis either
a string label or a dimension. The [bcast_] fields represent whether additional leading axes are
allowed for the given kind (corresponding to the dot-ellipsis syntax for broadcasting).
The [given_] fields count the axes specified for the corresponding kind; as parsed by
[axis_labels_of_spec], place-holder ["_"] axes are included in the count although they are absent
from [labels]. *)
(** [bcast_of_kind kind spec] tells whether [spec] allows broadcasting for axes of the given [kind]. *)
let bcast_of_kind kind (spec : parsed_axis_labels) =
  match kind with
  | AxisKey.Batch -> spec.bcast_batch
  | AxisKey.Input -> spec.bcast_input
  | AxisKey.Output -> spec.bcast_output
(** [given_of_kind kind spec] is the number of axes of the given [kind] that [spec] lists. *)
let given_of_kind kind (spec : parsed_axis_labels) =
  match kind with
  | AxisKey.Batch -> spec.given_batch
  | AxisKey.Input -> spec.given_input
  | AxisKey.Output -> spec.given_output
(** Renders an axis description for use in messages. *)
let dim_to_string dim =
  match dim with
  | Parallel -> "parallel"
  | Dim size -> Int.to_string size
  | Frozen size -> "frozen " ^ Int.to_string size
(** Holds for [Dim 1] and [Frozen 1] only. *)
let dim_1 dim = match dim with Dim size | Frozen size -> size = 1 | Parallel -> false
(** Applies [f] to the size carried by [Dim] or [Frozen], keeping the constructor; [Parallel] is
    returned unchanged. *)
let map_dim ~f dim =
  match dim with
  | Parallel -> Parallel
  | Dim size -> Dim (f size)
  | Frozen size -> Frozen (f size)
(** The dimensions of the axes of one kind, tagged with how they were obtained. *)
type dims =
| Given of dim list
(** User-provided dimensions. They will not change but will be broadcasted to bigger sizes. *)
| Fixed of dim list
(** User-provided dimensions that will fail if used in a different size context, even if broadcastable.
Note that [Operation.stop_broadcast] implements additional shape logic:
it converts the (bottom-up i.e. partially inferred) shape into a [Fixed] variant. *)
| Inferred of dim list
(** Dimensions that will themselves change to a bigger size: they adapt to the broadcasted size. *)
| Unknown
(** User-provided and will be replaced through inference. Prefer using [Unknown] to [Inferred []]. *)
[@@deriving equal, compare, sexp, variants]
(** Transforms the list of dimensions with [f], keeping the tag. [Unknown] is treated as the empty
    list and becomes [Inferred]. *)
let map_dims ~f dims =
  match dims with
  | Unknown -> Inferred (f [])
  | Inferred ds -> Inferred (f ds)
  | Given ds -> Given (f ds)
  | Fixed ds -> Fixed (f ds)
(** A constraint used to deduce dimensions; it is interpreted by the function [deduce_dims]. *)
type deduce_dims = Not_constrained | Input_equals_output | Input_output_scale of float
[@@deriving compare, sexp, variants]
(** [deduce_dims from constr] derives dimensions out of [from] as directed by [constr]:
    - [Not_constrained] yields [Unknown];
    - [Input_equals_output] yields the dimensions of [from], as [Inferred] when they are known;
    - [Input_output_scale sc] multiplies every size by [sc], rounding up, except that sizes equal
      to [1] are kept (for compatibility with broadcasting); the result is [Inferred], or [Unknown]
      when [from] is [Unknown].

    In practice [from] will be [Unknown] or [Inferred] dimensions, so how the [Given] and [Fixed]
    cases are interpreted here is of little relevance. *)
let deduce_dims from : deduce_dims -> dims =
  let scale_size sc size =
    if size = 1 then size else Float.(iround_exn ~dir:`Up (sc * of_int size))
  in
  fun constr ->
    match (constr, from) with
    | Not_constrained, _ -> Unknown
    | Input_equals_output, (Given ds | Fixed ds) -> Inferred ds
    | Input_equals_output, (Inferred _ | Unknown) -> from
    | Input_output_scale _, Unknown -> Unknown
    | Input_output_scale sc, (Given ds | Fixed ds | Inferred ds) ->
        Inferred (List.map ds ~f:(map_dim ~f:(scale_size sc)))
type t = {
mutable batch : dims;
mutable input : dims;
mutable output : dims;
mutable axis_labels : string axis_map;
(* Labels assigned to individual axes; axes without a label are absent from the map. *)
deduce_within_shape_constraints : deduce_dims;
(** Intended for terminal node cases where both [input] and [output] are initially
unknown. It makes it trivial to implement dimension-preserving hidden layers: just set
[deduce_within_shape_constraints=Input_equals_output]. *)
id : int; (** A node that has the same shape as this shape. *)
}
[@@deriving fields, sexp]
(** The datatype from which the actual Code shapes are computed. The dimensions are kept separately
for the batch, input and output axes.
Mutability is sufficient to perform inference, since there is no need for backtracking and
no explicit unification variables for now. [Unknown] stands for "not yet specified". *)
(** [dims_of_kind kind sh] reads the dimensions of the axes of the given [kind] from the shape [sh]. *)
let dims_of_kind kind (sh : t) =
  match kind with
  | AxisKey.Batch -> sh.batch
  | AxisKey.Input -> sh.input
  | AxisKey.Output -> sh.output
(** Returns a copy of [sh] where the dimensions of the given [kind] are replaced by the result of
    applying [f] to them. The argument [sh] is not modified. *)
let map_over_kind ~f kind sh =
  match kind with
  | AxisKey.Batch ->
      let batch = f sh.batch in
      { sh with batch }
  | AxisKey.Input ->
      let input = f sh.input in
      { sh with input }
  | AxisKey.Output ->
      let output = f sh.output in
      { sh with output }
(** Overwrites in place the dimensions of the given [kind] of [sh] with the result of applying [f]
    to their current value. *)
let update_kind ~f kind sh =
  let updated = f (dims_of_kind kind sh) in
  match kind with
  | AxisKey.Batch -> sh.batch <- updated
  | AxisKey.Input -> sh.input <- updated
  | AxisKey.Output -> sh.output <- updated
(** Forgets the tag of the dimensions; [Unknown] is mapped to the empty list. *)
let list_of_dims dims = match dims with Unknown -> [] | Given ds | Fixed ds | Inferred ds -> ds
(** Computes the axis labels of [sh] where the from-end position of every labeled axis of the given
    [kind] is transformed by [f]. Labels whose new position would be below [1] are dropped; labels
    of the other kinds are kept as they are. The shape itself is not modified.
    Raises (via [Map.of_alist_exn]) if [f] maps two labeled axes to the same position. *)
let shift_axes_of_kind kind sh ~f =
  let shift (({ AxisKey.in_axes; from_end }, label) as entry) =
    if AxisKey.equal_kind kind in_axes then
      let shifted = f from_end in
      if shifted >= 1 then Some ({ AxisKey.in_axes; from_end = shifted }, label) else None
    else Some entry
  in
  Map.to_alist sh.axis_labels |> List.filter_map ~f:shift |> Map.of_alist_exn (module AxisKey)
(** Builds a shape out of [main] by attaching, for every kind, the axes of another shape: in front of
    the axes of [main] when the other shape is passed as [prefix], behind them when it is passed as
    [suffix]. Exactly one of [prefix], [suffix] must be provided (this is asserted).

    The tag of the resulting dimensions is taken from [main], with [Unknown] becoming [Inferred].
    Axis labels of both shapes are retained, re-positioned to account for the added axes. All the
    remaining fields are copied from [main]. *)
let append_all_axes ?prefix ?suffix ~main () =
  let affix, is_prefix =
    match (prefix, suffix) with
    | Some sh, None -> (sh, true)
    | None, Some sh -> (sh, false)
    | Some _, Some _ | None, None -> assert false
  in
  let join affix_dims main_dims =
    let extra = list_of_dims affix_dims in
    let combine main_list = if is_prefix then extra @ main_list else main_list @ extra in
    match main_dims with
    | Unknown -> Inferred extra
    | Inferred ds -> Inferred (combine ds)
    | Given ds -> Given (combine ds)
    | Fixed ds -> Fixed (combine ds)
  in
  let batch = join affix.batch main.batch in
  let input = join affix.input main.input in
  let output = join affix.output main.output in
  (* Positions are counted from the end: the labels of the shape that ends up on the left move away
     from the end by the number of axes (of the same kind) of the shape on the right, whose labels
     stay in place. *)
  let left_sh, right_sh = if is_prefix then (affix, main) else (main, affix) in
  let shift_label ({ AxisKey.in_axes; from_end }, label) =
    let offset = List.length (list_of_dims (dims_of_kind in_axes right_sh)) in
    ({ AxisKey.in_axes; from_end = from_end + offset }, label)
  in
  let axis_labels =
    List.map (Map.to_alist left_sh.axis_labels) ~f:shift_label @ Map.to_alist right_sh.axis_labels
    |> Map.of_alist_exn (module AxisKey)
  in
  { main with batch; input; output; axis_labels }
(** The ways in which two argument shapes can be combined into a result shape. *)
type compose_type =
| Pointwise_bin (** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *)
| Compose
(** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape of
[fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix multiply). *)
| Einsum of string
(** The [einsum] syntax: LABELS1;LABELS2=>LABELS3, where LABELSi are labels specifications.
Note that currently [Compose] is not redundant with [Einsum], because it enables more shape
inference: [Einsum] is limited to [Pointwise_bin]-like broadcasting, while [Compose] broadcasts
inputs of the "operator" against outputs of the "operand" (matching up an arbitrary number of axes).
The [axis_labels] use pseudo-labels local to the notation, to line up the axes.
For [Einsum (ls1^";"^ls2^"=>"^ls3)], the symmetric difference / disjunctive union of [ls1] and [ls2]'s
pseudo-labels should be equal to [ls3] pseudo-labels.
Currently, we support two variants of the [einsum] syntax: either all the axes are provided,
or all input, output axes are provided but none of the batch axes.
Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs", "rhs1;rhs2=>lhs". *)
| Dynamic_index of {
over_kind : AxisKey.kind;
from_left : bool;
other_axes_pointwise : bool;
indexed_dims : int list option;
}
(** Uses RHS2 as an index into RHS1. The values of RHS2 along the last (output) axis are used to
fix the RHS1 [over_kind] axes. If RHS1 has more [over_kind] axes than the size of RHS2's last
axis, the remaining right axes are kept if [from_left] is true, otherwise the left axes
are kept. The [Parallel] axes are preserved -- skipped over, never dynamically indexed.
The fixed (indexed-into) axes are dropped from the shape of LHS. If RHS2 has more than one axis,
[other_axes_pointwise] decides what to do with the other axes: if true, they are traversed
pointwise with the corresponding axes of RHS1. For the pointwise alignment, we drop the last
output axis of RHS2 and the fixed axes of RHS1. If [other_axes_pointwise] is false,
the other axes of RHS2 are prepended to the remaining axes of RHS1 (of the corresponding
kind). ([other_axes_pointwise] being true is akin to an inner product, being false is akin to
an outer product.) If [indexed_dims] are not given, the shape of the indexed tensor (RHS1) must be given.
Otherwise, [indexed_dims] provide the dimensions of the indexed axes.
If otherwise unknown, [RHS2]'s output axis is inferred to be [Dim 1]. *)
[@@deriving sexp, equal]
(** The ways in which a result shape can be derived from a single argument shape. *)
type transpose_type =
| Transpose (** Swaps inputs and outputs of a shape, preserves batch axes. *)
| Pointwise_un (** Preserves the shape. *)
| Permute of string
(** [Permute (ls1^"=>"^ls2)] is a variant of the [einsum] syntax [Einsum (ls1^";"^ls1^"=>"^ls2)].
Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs". *)
[@@deriving sexp]
(** Parses a labels specification.

    * If [spec] contains any of: [' '; ','; '('; ')'], or a tab, carriage return or newline, these
      characters are used as label separators. Otherwise, every character is a label.
    * If [spec] does not contain ["|"] nor ["->"], each label is of the kind [Output].
    * If [spec] does not contain ["|"], labels to the left of ["->"] are [Input] and to the right [Output].
    * Labels to the left of ["|"] are [Batch], and between ["|"] and ["->"] are [Input].

    The label ["..."] is only allowed at the first axis of a kind (i.e. last from-end).
    It is used to enable broadcasting for the axis kind in the einsum-related shape inference
    (like the ellipsis ["..."] in [numpy.einsum]).

    The label ["_"] is a place-holder: it is not output to the resulting map but aligns the axes
    of other labels. A label that parses as an integer is recorded as a dimension ([Dim]) rather
    than as a string label.

    @raise Invalid_argument if ["..."] occurs anywhere other than at the start of the labels of a kind. *)
let axis_labels_of_spec spec : parsed_axis_labels =
  (* Detects and removes the leading ellipsis. The remainder must not contain another ellipsis,
     wherever it is located: e.g. both "a..." and "a ... b" are rejected. *)
  let check_dot s =
    let bcast = String.is_prefix s ~prefix:"..." in
    let rest = if bcast then String.drop_prefix s 3 else s in
    if String.is_substring rest ~substring:"..." then
      invalid_arg ("axis_labels_of_spec: dot only allowed at first axis of a kind: " ^ spec)
    else (bcast, rest)
  in
  (* Splits [s] around the first occurrence of [pattern], if there is one. *)
  let split_once s ~pattern =
    Option.map (String.substr_index s ~pattern) ~f:(fun at ->
        (String.prefix s at, String.drop_prefix s (at + String.length pattern)))
  in
  let separators = [ ' '; ','; '('; ')'; '\t'; '\r'; '\n' ] in
  (* Parses the labels of a single kind: returns the broadcast flag, and the number of axes listed
     paired with the map of labels. *)
  let parse spec in_axes =
    let bcast, spec = check_dot (String.strip spec) in
    let parse_label labels_num from_start s =
      let key = AxisKey.{ in_axes; from_end = labels_num - from_start } in
      if String.equal s "_" then None
      else
        (* Whatever fails to parse as an integer is a string label. *)
        try Some (key, Either.Second (Dim (Int.of_string s))) with _ -> Some (key, Either.First s)
    in
    let labels =
      if List.exists separators ~f:(String.contains spec) then
        String.split_on_chars spec ~on:separators |> List.filter ~f:(fun s -> not (String.is_empty s))
      else String.to_list spec |> List.map ~f:String.of_char
    in
    let labels_num = List.length labels in
    ( bcast,
      (labels_num, List.filter_mapi labels ~f:(parse_label labels_num) |> Map.of_alist_exn (module AxisKey))
    )
  in
  let batch_spec, spec = Option.value (split_once spec ~pattern:"|") ~default:("", spec) in
  let input_spec, output_spec = Option.value (split_once spec ~pattern:"->") ~default:("", spec) in
  let bcast_batch, (given_batch, batch_labels) = parse batch_spec Batch in
  let bcast_input, (given_input, input_labels) = parse input_spec Input in
  let bcast_output, (given_output, output_labels) = parse output_spec Output in
  (* The key ranges cannot overlap: keys are ordered by kind first, with [Batch] before [Input]
     before [Output]. *)
  let labels =
    match Map.append ~lower_part:input_labels ~upper_part:output_labels with
    | `Ok io_labels -> (
        match Map.append ~lower_part:batch_labels ~upper_part:io_labels with
        | `Ok all_labels -> all_labels
        | `Overlapping_key_ranges -> assert false)
    | `Overlapping_key_ranges -> assert false
  in
  { bcast_batch; bcast_input; bcast_output; given_batch; given_input; given_output; labels }
(** Parses an einsum-like specification of the form "rhs=>lhs" or "rhs1;rhs2=>lhs", where each part is
    a labels specification handled by [axis_labels_of_spec]. Returns the parsed labels as the triple
    [(rhs1, rhs2, lhs)], where [rhs2] is [None] if the second argument is absent or blank.

    @raise Invalid_argument if the result part or the argument part is missing or blank. *)
let einsum_of_spec spec =
  (* Splits [s] around the first occurrence of [pattern], if there is one. *)
  let split_once s ~pattern =
    Option.map (String.substr_index s ~pattern) ~f:(fun at ->
        (String.prefix s at, String.drop_prefix s (at + String.length pattern)))
  in
  let rhs_spec, lhs_spec =
    match split_once spec ~pattern:"=>" with
    | Some (rhs, lhs) -> (String.strip rhs, String.strip lhs)
    | None -> ("", String.strip spec)
  in
  if String.is_empty lhs_spec then invalid_arg ("einsum_of_spec: missing the result spec in " ^ spec);
  if String.is_empty rhs_spec then invalid_arg ("einsum_of_spec: missing the argument spec in " ^ spec);
  let rhs1_spec, rhs2_spec =
    match split_once rhs_spec ~pattern:";" with
    | Some (rhs1, rhs2) -> (String.strip rhs1, String.strip rhs2)
    | None -> (rhs_spec, "")
  in
  let lhs_ls = axis_labels_of_spec lhs_spec in
  let rhs1_ls = axis_labels_of_spec rhs1_spec in
  let rhs2_ls = if String.is_empty rhs2_spec then None else Some (axis_labels_of_spec rhs2_spec) in
  (rhs1_ls, rhs2_ls, lhs_ls)
(** How to propagate shape updates and do the last update of [Formula.t.shape] when finalizing the formula.
Axes are broadcast-expanded on a bottom-up update to fit the incoming shape. *)
type logic =
| Broadcast of compose_type * t * t
(** Matches the shapes for a binary operation, allowing for broadcasting e.g. an axis of dimension 1
does not conflict with a matching axis of a greater dimension.
For [Broadcast (Einsum (ls1^";"^ls2^"=>"^ls3), s1, s2)], the labels of [s1] and [s2] must match according
to the [ls1], [ls2] lineup, and the resulting shape inherits the labels according to the [ls3] lineup.
*)
| Transpose of transpose_type * t
(** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of [s1],
hence the name. *)
| Terminal
(* No argument shapes are attached to this case. *)
[@@deriving sexp]
type update_step = { shape : t; logic : logic } [@@deriving sexp]
(** Data required for a shape inference update step: the [shape] being updated, and the [logic] that
relates it to its argument shapes. A step should equilibrate information, passing it both
top-down and bottom-up. The child should be identifiable within the parent via physical equality
(allowing that a child fills both slots of a binary parent). *)
(** Raised when shapes are found to be inconsistent or insufficiently known. Carries the error message
and two shapes for debugging (the same shape twice when only one shape is involved). *)
exception Shape_error of string * t * t [@@deriving sexp]
(** Given a fully-inferred shape, maps axes to their corresponding positions in an index using the
    [Shape.to_dims] semantics: batch axes come first, followed by output axes, then input axes.

    Raises [Shape_error] if the dimensions of any kind are still [Unknown]. *)
let axis_keys_to_idcs (sh : t) : int axis_map =
  (* The keys of the axes of one kind, enumerated backwards: starting at the last axis. *)
  let keys_of kind dims ~unknown_msg =
    match dims with
    | Unknown -> raise (Shape_error (unknown_msg, sh, sh))
    | Inferred ds | Given ds | Fixed ds ->
        Array.of_list_mapi ds ~f:(fun i _ -> AxisKey.{ in_axes = kind; from_end = i + 1 })
  in
  let b_keys = keys_of AxisKey.Batch sh.batch ~unknown_msg:"Batch dimensions still unknown" in
  let i_keys = keys_of AxisKey.Input sh.input ~unknown_msg:"Input dimensions still unknown" in
  let o_keys = keys_of AxisKey.Output sh.output ~unknown_msg:"Output dimensions still unknown" in
  (* Assembled back-to-front, then reversed into the index order. *)
  let ordered = Array.concat [ i_keys; o_keys; b_keys ] in
  Array.rev_inplace ordered;
  Array.to_list ordered
  |> List.mapi ~f:(fun position key -> (key, position))
  |> Map.of_alist_exn (module AxisKey)
(** Converts an axes-keyed map into three arrays of values: for batch axes, input axes and output
    axes, in that order. Within an array, the values are laid out by axis position, with the axis at
    from-end [1] stored last; the length of an array is the greatest from-end position present for
    the kind.

    If the map is incomplete, the gaps in the arrays are filled with [default] when given; otherwise
    the result might be invalid, as the gaps are filled with an arbitrary one of the provided
    values. *)
let axis_map_to_dims_bio (type a) ?(default : a option) (idcs : a axis_map) =
  if Map.is_empty idcs then ([||], [||], [||])
  else
    let witness = match default with Some filler -> filler | None -> snd (Map.min_elt_exn idcs) in
    (* The array of the values for the axes of a single kind. *)
    let values_of_kind kind : a array =
      let entries =
        Map.to_alist idcs
        |> List.filter_map ~f:(fun ({ AxisKey.in_axes; from_end }, value) ->
               if AxisKey.equal_kind in_axes kind then Some (from_end, value) else None)
      in
      let size = List.fold entries ~init:0 ~f:(fun greatest (pos, _) -> max pos greatest) in
      let values = Array.create ~len:size witness in
      List.iter entries ~f:(fun (pos, value) -> values.(size - pos) <- value);
      values
    in
    let bch = values_of_kind AxisKey.Batch in
    let inp = values_of_kind AxisKey.Input in
    let out = values_of_kind AxisKey.Output in
    (bch, inp, out)
(** Converts an axes-keyed map into a single array of values using the [Shape.to_dims] semantics of
    axes: the values for batch axes, then for output axes, then for input axes.
    If the map is incomplete and the [~default] is not given, the result might be invalid: gaps in
    the array are filled with an arbitrary one of the provided values. *)
let axis_map_to_dims_index (type a) ?(default : a option) (idcs : a axis_map) : a array =
  let batch_values, input_values, output_values = axis_map_to_dims_bio ?default idcs in
  Array.concat [ batch_values; output_values; input_values ]
(** Splits the dimensions of a shape into a map from axes, putting at most one number in a [dims] of
    an axis, and keeping the tag of the dimensions. An empty [dims] list is an end-of-list sentinel:
    it is stored one position past the first axis of a kind, and means that there are one fewer axes
    of the particular kind. A kind whose dimensions are [Unknown] contributes a single [Unknown]
    entry at from-end [1]. *)
let to_axis_map (sh : t) : dims axis_map =
  let kind_dims kind =
    let key from_end = AxisKey.{ in_axes = kind; from_end } in
    (* [wrap] re-attaches the tag of the dimensions being split. *)
    let split wrap ds =
      let n_dims = List.length ds in
      (key (n_dims + 1), wrap []) :: List.rev_mapi ds ~f:(fun i d -> (key (n_dims - i), wrap [ d ]))
    in
    match dims_of_kind kind sh with
    | Unknown -> [ (key 1, Unknown) ]
    | Inferred ds -> split (fun l -> Inferred l) ds
    | Given ds -> split (fun l -> Given l) ds
    | Fixed ds -> split (fun l -> Fixed l) ds
  in
  List.concat_map [ AxisKey.Batch; AxisKey.Input; AxisKey.Output ] ~f:kind_dims
  |> Map.of_alist_exn (module AxisKey)
(* Design choice: tensor shapes are decided while code is constructed, although not immediately.
Due to mutable updates during shape inference, it is not possible to reuse the same formula with
different shapes. The inference is finalized by invoking the [Formula.subtree_shape_updates] once
on the root formula. *)
(** Generates a label for an axis. When an einsum-like [parsed_spec] is provided, the label is meant
    for a broadcasted axis: the axes that are part of the spec do not count towards the number in
    the label, so that the labels can be used to align axes across different shapes (lhs, rhs1,
    rhs2). Without [parsed_spec], the label starts with ["_fix_"] and uses the from-end position as
    is. *)
let gen_label_of_axis ?parsed_spec axis =
  let { AxisKey.in_axes; from_end } = axis in
  let prefix, idx =
    match parsed_spec with
    | Some spec -> ("_", from_end - given_of_kind in_axes spec)
    | None -> ("_fix_", from_end)
  in
  let kind_tag =
    match in_axes with AxisKey.Batch -> "__b" | AxisKey.Input -> "__i" | AxisKey.Output -> "__o"
  in
  prefix ^ kind_tag ^ Int.to_string idx
(** Re-tags in place the batch, input and output dimensions of [sh]: each becomes [typ] applied to
    its list of dimensions (the empty list for [Unknown]). *)
let set_dims_type sh typ =
  let retag dims = typ (list_of_dims dims) in
  sh.batch <- retag sh.batch;
  sh.input <- retag sh.input;
  sh.output <- retag sh.output
(** Augments the pseudo-labels map of an einsum notation with the generated labels for broadcasted
    axes. For each kind for which [ls_xhs] allows broadcasting, axes past the ones given in the spec
    are added for as long as their generated label is present in [all_labels]. *)
let axes_with_inf_labels ~all_labels ls_xhs =
  let extend accu kind =
    if not (bcast_of_kind kind ls_xhs) then accu
    else
      let offset = given_of_kind kind ls_xhs in
      let rec loop more accu =
        let axis = AxisKey.{ in_axes = kind; from_end = offset + more } in
        let label = gen_label_of_axis ~parsed_spec:ls_xhs axis in
        if Map.mem all_labels label then
          loop (more + 1) (Map.add_exn accu ~key:axis ~data:(Either.First label))
        else accu
      in
      loop 1 accu
  in
  List.fold [ AxisKey.Output; AxisKey.Input; AxisKey.Batch ] ~init:ls_xhs.labels ~f:extend
(** Turns a map of parsed labels into a map of string labels: a string label is kept as is, and an
    axis specified by a dimension gets a label generated from its key. *)
let axes_with_pseudo_labels labels =
  Map.mapi labels ~f:(fun ~key ~data ->
      match data with Either.Second _ -> gen_label_of_axis key | Either.First label -> label)
(** Holds for the user-provided dimensions that do not change: [Given] and [Fixed]. *)
let is_given_or_fixed dims = match dims with Given _ | Fixed _ -> true | Inferred _ | Unknown -> false
(* Abbreviations for the types of maps keyed by strings or by axes, each followed by its converter
to S-expressions. Every converter outputs the list of the key-value pairs of the map. *)
type eq_slot = [ `Lhs | `Rhs1 | `Rhs2 ] [@@deriving sexp]
type eqs_map = (string, (AxisKey.t * dims) list, Base.String.comparator_witness) Base.Map.t
type eqs_p_map =
(string, (AxisKey.t * (dims, dim * dims) Either.t) list, Base.String.comparator_witness) Base.Map.t
let sexp_of_eqs_map (map : eqs_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: string * (AxisKey.t * dims) list])
let sexp_of_eqs_p_map (map : eqs_p_map) =
Sexp.List
(Map.to_alist map |> List.map ~f:[%sexp_of: string * (AxisKey.t * (dims, dim * dims) Either.t) list])
type str_str_map = (string, string, Base.String.comparator_witness) Base.Map.t
let sexp_of_str_str_map (map : str_str_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: string * string])
type str_dim_map = (string, dim, Base.String.comparator_witness) Base.Map.t
let sexp_of_str_dim_map (map : str_dim_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: string * dim])
type str_fdim_map = (string, bool * dim, Base.String.comparator_witness) Base.Map.t
let sexp_of_str_fdim_map (map : str_fdim_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: string * (bool * dim)])
type axis_dim_map = (AxisKey.t, dim, AxisKey.comparator_witness) Base.Map.t
let sexp_of_axis_dim_map (map : axis_dim_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: AxisKey.t * dim])
type axis_fdim_map = (AxisKey.t, bool * dim, AxisKey.comparator_witness) Base.Map.t
let sexp_of_axis_fdim_map (map : axis_fdim_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: AxisKey.t * (bool * dim)])
type axis_str_map = (AxisKey.t, string, AxisKey.comparator_witness) Base.Map.t
let sexp_of_axis_str_map (map : axis_str_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: AxisKey.t * string])
type axis_plab_map = (AxisKey.t, (string, dim) Either.t, AxisKey.comparator_witness) Base.Map.t
let sexp_of_axis_plab_map (map : axis_plab_map) =
Sexp.List (Map.to_alist map |> List.map ~f:[%sexp_of: AxisKey.t * (string, dim) Either.t])
(** Removes the first [subs] non-[Parallel] axes from [dims]. The [Parallel] axes encountered while
    there still are axes to remove are retained, in their original order, in front of the rest. *)
let drop_keeping_parallel subs dims =
  let rec loop remaining kept rest =
    if remaining <= 0 then List.rev_append kept rest
    else
      match rest with
      | [] -> List.rev kept
      | Parallel :: tl -> loop remaining (Parallel :: kept) tl
      | (Dim _ | Frozen _) :: tl -> loop (remaining - 1) kept tl
  in
  loop subs [] dims
(** Returns the first [subs] non-[Parallel] axes of [dims] (fewer if [dims] runs out), leaving out
    the [Parallel] axes encountered on the way. *)
let take_skipping_parallel subs dims =
  let rec loop taken remaining rest =
    if remaining <= 0 then List.rev taken
    else
      match rest with
      | [] -> List.rev taken
      | Parallel :: tl -> loop taken remaining tl
      | ((Dim _ | Frozen _) as dim) :: tl -> loop (dim :: taken) (remaining - 1) tl
  in
  loop [] subs dims
(** Returns the leading axes of [dims] up to and including the first [subs] non-[Parallel] axes
    (fewer if [dims] runs out); the [Parallel] axes encountered on the way are included but not
    counted. When [indexed_dims] is given, its sizes are matched one by one against the counted
    axes: each such axis must be a [Dim] or a [Frozen] axis of the expected size. Axes counted after
    [indexed_dims] is exhausted are not checked.

    Raises [Shape_error], carrying [debug_sh1] and [debug_sh2], when a size requirement is not met. *)
let take_keeping_parallel ~debug_sh1 ~debug_sh2 ?(indexed_dims = []) subs dims =
  let has_size size dim = equal_dim (Dim size) dim || equal_dim (Frozen size) dim in
  let rec loop taken remaining expected rest =
    if remaining <= 0 then List.rev taken
    else
      match (rest, expected) with
      | [], _ -> List.rev taken
      | Parallel :: tl, _ -> loop (Parallel :: taken) remaining expected tl
      | dim :: tl, [] -> loop (dim :: taken) (remaining - 1) [] tl
      | dim :: tl, size :: more_expected ->
          if has_size size dim then loop (dim :: taken) (remaining - 1) more_expected tl
          else raise (Shape_error ("Dynamic indexing: axis size requirement not met", debug_sh1, debug_sh2))
  in
  loop [] subs indexed_dims dims
(** Counts the [Parallel] axes of [dims] that are located before the [subs]-th non-[Parallel] axis
    (or anywhere in [dims], if there are fewer than [subs] non-[Parallel] axes). *)
let count_parallel_among subs dims =
  let rec loop count remaining rest =
    if remaining <= 0 then count
    else
      match rest with
      | [] -> count
      | Parallel :: tl -> loop (count + 1) remaining tl
      | (Dim _ | Frozen _) :: tl -> loop count (remaining - 1) tl
  in
  loop 0 subs dims
(** Performs a local step of shape inference, propagates information into and out of the parent shape
and the child shape(s). *)
let rec propagate_shapes (update : update_step) =
let pointwise_labels debug1 debug2 ls1 ls2 =
Map.merge ls1 ls2 ~f:(fun ~key -> function
| `Both (l1, l2) ->
if String.equal l1 l2 then Some l1
else
let error =
"Axis label mismatch: " ^ l1 ^ " vs " ^ l2 ^ " for " ^ Sexp.to_string_hum
@@ AxisKey.sexp_of_t key
in
raise @@ Shape_error (error, debug1, debug2)
| `Right l | `Left l -> Some l)
in
let broad_dim ~fixed_left ~fixed_right debug1 debug2 axis_key label = function
| d1, d2 when equal_dim d1 d2 -> d1
| Dim 1, d when not fixed_left -> d
| d, Dim 1 when not fixed_right -> d
| d1, d2 ->
let opt_label = match label with None -> "" | Some l -> " (" ^ l ^ ")" in
let error =
"Dimension mismatch for axis " ^ AxisKey.to_string axis_key ^ opt_label ^ ": " ^ dim_to_string d1
^ " vs. " ^ dim_to_string d2
in
raise @@ Shape_error (error, debug1, debug2)
in
(* If initially [lhs] is [Unknown], [rhs1] is [sh1] and [rhs2] is [sh2],
then [lhs] becomes [broadcast_dims sh1 sh2]. *)
let broadcast_dims sh1 sh2 kind labels sh1_dims sh2_dims =
let rec broad_back_dims ~fixed_left ~fixed_right accu i = function
| [], [] -> accu
| [], dims when not fixed_left -> List.rev_append dims accu
| dims, [] when not fixed_right -> List.rev_append dims accu
| [], _ | _, [] ->
let key = AxisKey.{ in_axes = kind; from_end = i } in
let opt_label = match Map.find labels key with None -> "" | Some l -> " (" ^ l ^ ")" in
let error = "Different number of axes around from-end " ^ AxisKey.to_string key ^ opt_label in
raise @@ Shape_error (error, sh1, sh2)
| d1 :: dims1, d2 :: dims2 ->
let key = AxisKey.{ in_axes = kind; from_end = i } in
broad_back_dims ~fixed_left ~fixed_right
(broad_dim ~fixed_left ~fixed_right sh1 sh2 key (Map.find labels key) (d1, d2) :: accu)
(i + 1) (dims1, dims2)
in
let broadcast_dims ~dims1 ~dims2 =
broad_back_dims ~fixed_left:(is_fixed sh1_dims) ~fixed_right:(is_fixed sh2_dims) [] 1
(List.rev dims1, List.rev dims2)
in
match (sh1_dims, sh2_dims) with
| Unknown, Unknown -> Unknown
| (Inferred dims | Given dims | Fixed dims), Unknown | Unknown, (Inferred dims | Given dims | Fixed dims)
->
Inferred dims
| Fixed dims1, Fixed dims2 -> Fixed (broadcast_dims ~dims1 ~dims2)
| (Given dims1 | Fixed dims1), (Given dims2 | Fixed dims2) -> Given (broadcast_dims ~dims1 ~dims2)
| (Inferred dims1 | Given dims1 | Fixed dims1), (Inferred dims2 | Given dims2 | Fixed dims2) ->
Inferred (broadcast_dims ~dims1 ~dims2)
in
let cur_sh = update.shape in
(* Note: does not work with arbitrary permutation as in einsum. *)
let update_labels sh1 to_kind sh2 from_kind =
pointwise_labels sh1 sh2 sh1.axis_labels
@@ Map.map_keys_exn (module AxisKey) ~f:(fun k -> { k with in_axes = to_kind })
@@ Map.filter_keys sh2.axis_labels ~f:AxisKey.(fun k -> equal_kind k.in_axes from_kind)
in
(* [broadcast_into ?det to_sh to_kind from_sh from_kind] computes the new dims for the [to_kind]
   axes of [to_sh] after broadcasting against the [from_kind] axes of [from_sh].

   - When the destination dims are [Given] or [Fixed], [broadcast_dims] is run only as a
     compatibility check (its result is discarded) and the destination dims come back unchanged.
   - Otherwise the labels of [to_sh] are updated in place via [update_labels], and the broadcast
     result is returned. With [~det:true], an [Inferred] result is promoted to [Fixed] / [Given]
     when the source dims were [Fixed] / [Given] respectively.

   The caller is responsible for storing the returned dims. [broadcast_dims] may raise
   [Shape_error]. *)
let broadcast_into ?(det = false) to_sh to_kind from_sh from_kind =
  let from_dims = dims_of_kind from_kind from_sh in
  let into_dims = dims_of_kind to_kind to_sh in
  match into_dims with
  | Given _ | Fixed _ ->
      (* The destination is already determined: validate, but keep it as is. *)
      ignore (broadcast_dims to_sh from_sh to_kind to_sh.axis_labels into_dims from_dims);
      into_dims
  | Unknown | Inferred _ ->
      to_sh.axis_labels <- update_labels to_sh to_kind from_sh from_kind;
      let result = broadcast_dims to_sh from_sh to_kind to_sh.axis_labels into_dims from_dims in
      if not det then result
      else (
        (* Carry over how determined the source was. *)
        match (from_dims, result) with
        | Fixed _, Inferred dims -> Fixed dims
        | Given _, Inferred dims -> Given dims
        | _ -> result)
in
(* [einsum_one_dim_opt debug_spec debug1 debug2 label terms] reconciles all the occurrences of one
   einsum pseudo-label into a single dimension.

   [terms] is a list of [(axis, dims)] pairs, where each [dims] holds at most one dimension (more
   than one is an internal error, see the first arm). The fold accumulator is [(is_fixed, dim)]:
   [dim] is the dimension agreed on so far, if any, and [is_fixed] records whether a [Fixed] term
   has contributed to it.

   [debug_spec], [debug1] and [debug2] are only used to build [Shape_error] payloads; [label] is
   the pseudo-label mentioned in the error messages.

   Raises [Shape_error] on a dimension mismatch, and when an empty term meets a fixed constraint.
   The order of the match arms matters: the guarded arms must come before the generic mismatch arm. *)
let einsum_one_dim_opt debug_spec debug1 debug2 label terms =
List.fold terms ~init:(false, None) ~f:(fun ((is_fixed, dim) as accu) (_axis, dims) ->
match (dim, dims) with
(* Each term describes a single axis, so it can never carry two or more dimensions. *)
| _, (Inferred (_ :: _ :: _) | Given (_ :: _ :: _) | Fixed (_ :: _ :: _)) -> assert false
(* An [Unknown] term contributes no information. *)
| None, Unknown ->
assert (not is_fixed);
(false, None)
| Some _, Unknown -> accu
(* First known dimension for this label: record it; only a [Fixed] term marks it as fixed. *)
| None, (Inferred [ dim2 ] | Given [ dim2 ]) ->
assert (not is_fixed);
(false, Some dim2)
| None, Fixed [ dim2 ] ->
assert (not is_fixed);
(true, Some dim2)
(* Agreement with the accumulated dimension: keep it, upgrading to fixed on a [Fixed] term. *)
| Some dim1, (Inferred [ dim2 ] | Given [ dim2 ]) when equal_dim dim1 dim2 -> accu
| Some dim1, Fixed [ dim2 ] when equal_dim dim1 dim2 -> (true, dim)
(* Broadcast: an accumulated, non-fixed [Dim 1] is replaced by the dimension of the new term.
   NOTE(review): only the accumulated side is tested for [Dim 1]; a later term that differs from
   the accumulated dimension falls through to the mismatch error below, even if it is [Dim 1]. *)
| Some (Dim 1), (Inferred [ dim2 ] | Given [ dim2 ]) when not is_fixed -> (false, Some dim2)
(* Any remaining pair of single dimensions is a mismatch. *)
| Some dim1, (Inferred [ dim2 ] | Given [ dim2 ] | Fixed [ dim2 ]) ->
raise
@@ Shape_error
( ("Dimension mismatch " ^ dim_to_string dim1 ^ " vs. " ^ dim_to_string dim2
^ " for einsum pseudo-label " ^ label ^ " of " ^ debug_spec
^ if dim_1 dim1 || dim_1 dim2 then " (broadcast prevented)" else ""),
debug1,
debug2 )
(* A [Fixed] term with no axis cannot be extended to provide the axis for this label. *)
| _, Fixed [] ->
raise
@@ Shape_error
( "Too few fixed axes at einsum pseudo-label " ^ label ^ " of " ^ debug_spec
^ " (broadcast prevented)",
debug1,
debug2 )
(* An empty non-fixed term is an error only once the accumulated dimension is fixed. *)
| _, (Inferred [] | Given []) when is_fixed ->
raise
@@ Shape_error
( "Too few actual axes at einsum pseudo-label " ^ label ^ " of " ^ debug_spec
^ " (broadcast prevented)",
debug1,
debug2 )
(* Otherwise an empty term leaves the accumulator unchanged. *)
| _, (Inferred [] | Given []) -> accu)
in
(* [einsum_one_dim debug_spec debug1 debug2 ~key ~data] is the total version of
   [einsum_one_dim_opt], shaped as a [Map.mapi] callback: [key] is the pseudo-label and [data] its
   list of terms. When no term determines the dimension, the result defaults to a non-fixed
   [Dim 1]. A fixed result without a dimension is an internal error. *)
let einsum_one_dim debug_spec debug1 debug2 ~key ~data =
  let is_fixed, dim_opt = einsum_one_dim_opt debug_spec debug1 debug2 key data in
  match dim_opt with
  | Some dim -> (is_fixed, dim)
  | None ->
      if is_fixed then assert false
      else (false, Dim 1 (* which can still be expanded/broadcasted *))
in
(* [einsum_to_dims orig_dims is_bcast fdims] packs the per-axis results [fdims], an array of
   [(is_fixed, dim)] pairs, back into a [dims] value:
   - [Fixed] as soon as any axis is fixed;
   - otherwise, when [is_bcast] holds, [Inferred] or [Fixed] following [orig_dims];
   - [Given] in all remaining cases. *)
let einsum_to_dims orig_dims is_bcast fdims =
  let any_fixed = Array.exists fdims ~f:(fun (fixed, _) -> fixed) in
  let dim_list = Array.to_list (Array.map fdims ~f:(fun (_, d) -> d)) in
  if any_fixed then Fixed dim_list
  else if not is_bcast then Given dim_list
  else
    match orig_dims with
    | Inferred _ -> Inferred dim_list
    | Fixed _ -> Fixed dim_list
    | Unknown | Given _ -> Given dim_list
in
(* [eqs_xhs debug_spec debug_sh ls_xhs sh_xhs] collects, for one side of an einsum/permutation
   spec, the constraints on each (pseudo-)label.

   [ls_xhs] is the parsed spec for that side: its [labels] field maps axes to either a label
   ([First]) or an explicit index ([Second]). [sh_xhs] maps the axes of the corresponding shape to
   their dims. The two maps are merged per axis into [(label, (axis, dims))] entries, which are
   then grouped by label: the result maps each label to the list of [(axis, dims)] terms that
   mention it.

   [debug_spec] and [debug_sh] are only used in error payloads, together with [cur_sh] from the
   enclosing scope.

   Raises [Shape_error] when an explicit index is out of bounds for the dimension of its axis, or
   when the shape has a non-empty axis that the spec does not cover and broadcasting is not
   enabled for that kind of axes. The order of the merge arms matters. *)
let eqs_xhs debug_spec debug_sh ls_xhs sh_xhs =
let eqs =
Map.merge ls_xhs.labels sh_xhs ~f:(fun ~key:axis -> function
(* A labelled axis: constrain the label by the dims of the axis in the shape... *)
| `Both (Either.First label, dim) -> Some (label, (axis, dim))
(* ... or by nothing, if the shape does not have this axis yet. *)
| `Left (First label) -> Some (label, (axis, Inferred []))
(* An explicit index must be strictly below the single known dimension of its axis. *)
| `Both
( Second (Dim at | Frozen at),
( Given [ (Dim dim | Frozen dim) ]
| Fixed [ (Dim dim | Frozen dim) ]
| Inferred [ (Dim dim | Frozen dim) ] ) )
when at >= dim ->
raise
@@ Shape_error ("Specified dimension outside bounds for its axis: " ^ debug_spec, debug_sh, cur_sh)
(* Indexed axes get a label generated from the axis itself. *)
| `Both (Second _, dim) -> Some (gen_label_of_axis axis, (axis, dim))
(* An index on an axis missing from the shape implies a dimension of at least [at + 1]. *)
| `Left (Second (Dim at)) -> Some (gen_label_of_axis axis, (axis, Inferred [ Dim (at + 1) ]))
| `Left (Second (Frozen at)) -> Some (gen_label_of_axis axis, (axis, Inferred [ Frozen (at + 1) ]))
| `Left (Second Parallel) ->
Some (gen_label_of_axis axis, (axis, Inferred [ Parallel ])) (* FIXME: is it correct? *)
(* Shape entries that carry no dimension impose no constraint. *)
| `Right (Given [] | Fixed [] | Inferred [] | Unknown) -> None
(* An axis of the shape not covered by the spec is allowed only with broadcasting for its kind. *)
| `Right _dim when not (bcast_of_kind axis.in_axes ls_xhs) ->
raise
@@ Shape_error ("Too many axes to permute -- spec too short: " ^ debug_spec, debug_sh, cur_sh)
(* Note: the too-few-axes error is reported when einsum_one_dim processes the result. *)
| `Right dim -> Some (gen_label_of_axis ~parsed_spec:ls_xhs axis, (axis, dim)))
in
(* Drop the axis keys and group the terms by label. *)
Map.of_alist_multi (module String) @@ Map.data eqs
in
(* [pseudo_to_labels_xhs xhs_labels sh] builds a map from pseudo-labels to the actual axis labels
   of [sh]. [xhs_labels] maps axes to pseudo-labels, and [sh.axis_labels] maps axes to labels; an
   axis present in both contributes one [(pseudo, label)] entry.

   An axis labelled in [sh] but absent from [xhs_labels] is treated as an internal error
   ([assert false]). [Map.of_alist_exn] raises if two axes yield the same pseudo-label. *)
let pseudo_to_labels_xhs xhs_labels sh =
  let pair_up ~key:_ = function
    | `Left _pseudo -> None
    | `Right _label -> assert false
    | `Both (pseudo, label) -> Some (pseudo, label)
  in
  let per_axis = Map.merge xhs_labels sh.axis_labels ~f:pair_up in
  Map.of_alist_exn (module String) (Map.data per_axis)
in
(* [all_axis_labels debug1 debug2 debug_spec pseudo_to_labels_1 pseudo_to_labels_2] takes the
   union of two maps from pseudo-labels to axis labels. A pseudo-label present in only one map
   keeps its label; one present in both must map to equal labels, otherwise [Shape_error] is
   raised with [debug1] and [debug2] as the offending shapes. *)
let all_axis_labels debug1 debug2 debug_spec pseudo_to_labels_1 pseudo_to_labels_2 =
  let reconcile ~key:pseudo = function
    | `Left label | `Right label -> Some label
    | `Both (l1, l2) ->
        if String.equal l1 l2 then Some l1
        else
          let error =
            "Axis label mismatch: " ^ l1 ^ " vs " ^ l2 ^ " for pseudo label " ^ pseudo ^ " of spec "
            ^ debug_spec
          in
          raise (Shape_error (error, debug1, debug2))
  in
  Map.merge pseudo_to_labels_1 pseudo_to_labels_2 ~f:reconcile
in
match update.logic with
| Terminal -> ()
| Transpose (Transpose, sh) ->
cur_sh.input <- broadcast_into ~det:true cur_sh Input sh Output;
cur_sh.output <- broadcast_into ~det:true cur_sh Output sh Input;
cur_sh.batch <- broadcast_into ~det:true cur_sh Batch sh Batch;
sh.input <- broadcast_into sh Input cur_sh Output;
sh.output <- broadcast_into sh Output cur_sh Input;
sh.batch <- broadcast_into sh Batch cur_sh Batch
| Transpose (Pointwise_un, sh) ->
cur_sh.input <- broadcast_into ~det:true cur_sh Input sh Input;
cur_sh.output <- broadcast_into ~det:true cur_sh Output sh Output;
cur_sh.batch <- broadcast_into ~det:true cur_sh Batch sh Batch;
sh.input <- broadcast_into sh Input cur_sh Input;
sh.output <- broadcast_into sh Output cur_sh Output;
sh.batch <- broadcast_into sh Batch cur_sh Batch
| Transpose (Permute spec, sh) ->
let ls_rhs, ls_lhs =
match einsum_of_spec spec with
| ls_rhs, None, ls_lhs -> (ls_rhs, ls_lhs)
| _ -> raise @@ Shape_error ("Invalid permutation spec (expected one argument): " ^ spec, sh, cur_sh)
in
let sh_rhs : dims axis_map = to_axis_map sh in
let sh_lhs : dims axis_map = to_axis_map cur_sh in
let eqs_rhs : eqs_map = eqs_xhs spec sh ls_rhs sh_rhs in
let eqs_lhs : eqs_map = eqs_xhs spec sh ls_lhs sh_lhs in
let eqs : eqs_map =
Map.merge eqs_rhs eqs_lhs ~f:(fun ~key:_label -> function
| `Both (rhs, lhs) -> Some (rhs @ lhs) | `Left rhs -> Some rhs | `Right lhs -> Some lhs)
in
let label_dims : str_fdim_map = Map.mapi eqs ~f:(einsum_one_dim spec cur_sh sh) in
let lhs_plabels : axis_plab_map = axes_with_inf_labels ~all_labels:label_dims ls_lhs in
let lhs_labels : axis_str_map = axes_with_pseudo_labels lhs_plabels in
(* To reassign labels across repeated pseudo-labels, we can forget the integers. *)
let pseudo_to_labels_lhs : str_str_map = pseudo_to_labels_xhs lhs_labels cur_sh in
let inferred_lhs : axis_fdim_map = Map.map lhs_labels ~f:(Map.find_exn label_dims) in
let b_lhs, i_lhs, o_lhs = axis_map_to_dims_bio inferred_lhs in
if is_inferred cur_sh.batch || is_unknown cur_sh.batch then
cur_sh.batch <- einsum_to_dims cur_sh.batch ls_lhs.bcast_batch b_lhs;
if is_inferred cur_sh.input || is_unknown cur_sh.input then
cur_sh.input <- einsum_to_dims cur_sh.input ls_lhs.bcast_input i_lhs;
if is_inferred cur_sh.output || is_unknown cur_sh.output then
cur_sh.output <- einsum_to_dims cur_sh.output ls_lhs.bcast_output o_lhs;
let rhs_plabels : axis_plab_map = axes_with_inf_labels ~all_labels:label_dims ls_rhs in
let rhs_labels : axis_str_map = axes_with_pseudo_labels rhs_plabels in
let pseudo_to_labels_rhs : str_str_map = pseudo_to_labels_xhs rhs_labels sh in
let inferred_rhs : axis_fdim_map = Map.map rhs_labels ~f:(Map.find_exn label_dims) in
let b_rhs, i_rhs, o_rhs = axis_map_to_dims_bio inferred_rhs in
if is_inferred sh.batch || is_unknown sh.batch then
sh.batch <- einsum_to_dims sh.batch ls_rhs.bcast_batch b_rhs;
if is_inferred sh.input || is_unknown sh.input then
sh.input <- einsum_to_dims sh.input ls_rhs.bcast_input i_rhs;
if is_inferred sh.output || is_unknown sh.output then
sh.output <- einsum_to_dims sh.output ls_rhs.bcast_output o_rhs;
let all_axis_labels : str_str_map =
all_axis_labels cur_sh sh spec pseudo_to_labels_lhs pseudo_to_labels_rhs
in
let lhs_axis_labels : axis_str_map = Map.filter_map lhs_labels ~f:(Map.find all_axis_labels) in
cur_sh.axis_labels <- lhs_axis_labels;
let rhs_axis_labels : axis_str_map = Map.filter_map rhs_labels ~f:(Map.find all_axis_labels) in
sh.axis_labels <- rhs_axis_labels
| Broadcast (Pointwise_bin, sh1, sh2) ->
let up_labels = pointwise_labels sh1 sh2 sh1.axis_labels sh2.axis_labels in
cur_sh.axis_labels <- up_labels;
(* Note: will not work as expected (propagate givenness/fixedness) if the shape is pre-filled
as [Inferred] instead of [Unknown]. *)
if is_unknown cur_sh.input then
cur_sh.input <- broadcast_dims sh1 sh2 AxisKey.Input up_labels sh1.input sh2.input
else (
cur_sh.input <- broadcast_into cur_sh Input sh1 Input;
cur_sh.input <- broadcast_into cur_sh Input sh2 Input);
if is_unknown cur_sh.output then
cur_sh.output <- broadcast_dims sh1 sh2 AxisKey.Output up_labels sh1.output sh2.output
else (
cur_sh.output <- broadcast_into cur_sh Output sh1 Output;
cur_sh.output <- broadcast_into cur_sh Output sh2 Output);
if is_unknown cur_sh.batch then
cur_sh.batch <- broadcast_dims sh1 sh2 AxisKey.Batch up_labels sh1.batch sh2.batch
else (
cur_sh.batch <- broadcast_into cur_sh Batch sh1 Batch;
cur_sh.batch <- broadcast_into cur_sh Batch sh2 Batch);
sh1.input <- broadcast_into sh1 Input cur_sh Input;
sh1.output <- broadcast_into sh1 Output cur_sh Output;
sh1.batch <- broadcast_into sh1 Batch cur_sh Batch;
sh2.input <- broadcast_into sh2 Input cur_sh Input;
sh2.output <- broadcast_into sh2 Output cur_sh Output;
sh2.batch <- broadcast_into sh2 Batch cur_sh Batch
| Broadcast (Compose, sh1, sh2) ->
(* [sh2] is the value or the function that gets applied first: [cur_sh(x) = sh1(sh2(x))].
I.e. [cur.I = sh2.I, cur.O = sh1.O, sh2.O = sh1.I]. *)
cur_sh.input <- broadcast_into ~det:true cur_sh AxisKey.Input sh2 AxisKey.Input;
cur_sh.output <- broadcast_into ~det:true cur_sh AxisKey.Output sh1 AxisKey.Output;
if is_unknown cur_sh.batch then (
let up_labels = update_labels cur_sh Batch sh1 Batch in
cur_sh.axis_labels <- up_labels;
let up_labels = update_labels cur_sh Batch sh2 Batch in
cur_sh.axis_labels <- up_labels;
cur_sh.batch <- broadcast_dims sh1 sh2 AxisKey.Batch up_labels sh1.batch sh2.batch)
else (
cur_sh.batch <- broadcast_into cur_sh Batch sh1 Batch;
cur_sh.batch <- broadcast_into cur_sh Batch sh2 Batch);
sh1.input <- broadcast_into sh1 Input sh2 Output;
sh1.output <- broadcast_into sh1 Output cur_sh Output;
sh1.batch <- broadcast_into sh1 Batch cur_sh Batch;
sh2.input <- broadcast_into sh2 Input cur_sh Input;
sh2.output <- broadcast_into sh2 Output sh1 Input;
sh2.batch <- broadcast_into sh2 Batch cur_sh Batch;
(* Always re-derive the output shape, to have the latest information. *)
(* TODO: isn't it wasteful to discard the old sh1.output? *)
if not @@ is_not_constrained sh1.deduce_within_shape_constraints then
sh1.output <- deduce_dims sh2.input sh1.deduce_within_shape_constraints
(* TODO(#37):
if not @@ is_not_constrained sh1.deduce_input_from_output then
sh1.input <- deduce_dims sh2.output sh1.deduce_input_from_output *)
| Broadcast (Einsum spec, sh1, sh2) ->
let ls_rhs1, ls_rhs2, ls_lhs =
match einsum_of_spec spec with
| ls_rhs1, Some ls_rhs2, ls_lhs -> (ls_rhs1, ls_rhs2, ls_lhs)
| _ -> raise @@ Shape_error ("Invalid einsum spec (expected two arguments): " ^ spec, sh1, sh2)
in
let sh_rhs1 : dims axis_map = to_axis_map sh1 in
let sh_rhs2 : dims axis_map = to_axis_map sh2 in
let sh_lhs : dims axis_map = to_axis_map cur_sh in
let eqs_rhs1 : eqs_map = eqs_xhs spec sh1 ls_rhs1 sh_rhs1 in
let eqs_rhs2 : eqs_map = eqs_xhs spec sh2 ls_rhs2 sh_rhs2 in
let eqs_lhs : eqs_map = eqs_xhs spec sh1 ls_lhs sh_lhs in
let side_eq side (axis, dims) = ((side, axis), dims) in
let eqs =
Map.merge eqs_rhs1 eqs_lhs ~f:(fun ~key:_label -> function
| `Both (rhs, lhs) ->
Some (List.rev_map_append rhs ~f:(side_eq `Rhs1) @@ List.map lhs ~f:(side_eq `Lhs))
| `Left rhs -> Some (List.map rhs ~f:(side_eq `Rhs1))
| `Right lhs -> Some (List.map lhs ~f:(side_eq `Lhs)))
in
let eqs =
Map.merge eqs_rhs2 eqs ~f:(fun ~key:_label -> function
| `Both (rhs, more) -> Some (List.rev_map_append rhs ~f:(side_eq `Rhs2) more)
| `Left rhs -> Some (List.map rhs ~f:(side_eq `Rhs2))
| `Right more -> Some more)
in
let label_dims : str_fdim_map = Map.mapi eqs ~f:(einsum_one_dim spec sh1 sh2) in
let lhs_plabels : axis_plab_map = axes_with_inf_labels ~all_labels:label_dims ls_lhs in
let lhs_labels : axis_str_map = axes_with_pseudo_labels lhs_plabels in
let pseudo_to_labels_lhs : str_str_map = pseudo_to_labels_xhs lhs_labels cur_sh in
let inferred_lhs : axis_fdim_map = Map.map lhs_labels ~f:(Map.find_exn label_dims) in
let b_lhs, i_lhs, o_lhs = axis_map_to_dims_bio inferred_lhs in
if is_inferred cur_sh.batch || is_unknown cur_sh.batch then
cur_sh.batch <- einsum_to_dims cur_sh.batch ls_lhs.bcast_batch b_lhs;
if is_inferred cur_sh.input || is_unknown cur_sh.input then
cur_sh.input <- einsum_to_dims cur_sh.input ls_lhs.bcast_input i_lhs;
if is_inferred cur_sh.output || is_unknown cur_sh.output then
cur_sh.output <- einsum_to_dims cur_sh.output ls_lhs.bcast_output o_lhs;
let rhs1_plabels : axis_plab_map = axes_with_inf_labels ~all_labels:label_dims ls_rhs1 in
let rhs1_labels : axis_str_map = axes_with_pseudo_labels rhs1_plabels in
let pseudo_to_labels_rhs1 : str_str_map = pseudo_to_labels_xhs rhs1_labels sh1 in
let inferred_rhs1 : axis_fdim_map = Map.map rhs1_labels ~f:(Map.find_exn label_dims) in
let b_rhs1, i_rhs1, o_rhs1 = axis_map_to_dims_bio inferred_rhs1 in
if is_inferred sh1.batch || is_unknown sh1.batch then
sh1.batch <- einsum_to_dims sh1.batch ls_rhs1.bcast_batch b_rhs1;
if is_inferred sh1.input || is_unknown sh1.input then
sh1.input <- einsum_to_dims sh1.input ls_rhs1.bcast_input i_rhs1;
if is_inferred sh1.output || is_unknown sh1.output then
sh1.output <- einsum_to_dims sh1.output ls_rhs1.bcast_output o_rhs1;
let rhs2_plabels : axis_plab_map = axes_with_inf_labels ~all_labels:label_dims ls_rhs2 in
let rhs2_labels : axis_str_map = axes_with_pseudo_labels rhs2_plabels in
let pseudo_to_labels_rhs2 : str_str_map = pseudo_to_labels_xhs rhs2_labels sh2 in
let inferred_rhs2 : axis_fdim_map = Map.map rhs2_labels ~f:(Map.find_exn label_dims) in
let b_rhs2, i_rhs2, o_rhs2 = axis_map_to_dims_bio inferred_rhs2 in
if is_inferred sh2.batch || is_unknown sh2.batch then
sh2.batch <- einsum_to_dims sh2.batch ls_rhs2.bcast_batch b_rhs2;
if is_inferred sh2.input || is_unknown sh2.input then
sh2.input <- einsum_to_dims sh2.input ls_rhs2.bcast_input i_rhs2;
if is_inferred sh2.output || is_unknown sh2.output then
sh2.output <- einsum_to_dims sh2.output ls_rhs2.bcast_output o_rhs2;
let all_axis_labels1 : str_str_map =
all_axis_labels cur_sh sh1 spec pseudo_to_labels_lhs pseudo_to_labels_rhs1
in
let all_axis_labels : str_str_map =
all_axis_labels cur_sh sh2 spec all_axis_labels1 pseudo_to_labels_rhs2
in
let lhs_axis_labels : axis_str_map = Map.filter_map lhs_labels ~f:(Map.find all_axis_labels) in
cur_sh.axis_labels <- lhs_axis_labels;
let rhs1_axis_labels : axis_str_map = Map.filter_map rhs1_labels ~f:(Map.find all_axis_labels) in
sh1.axis_labels <- rhs1_axis_labels;
let rhs2_axis_labels : axis_str_map = Map.filter_map rhs2_labels ~f:(Map.find all_axis_labels) in
sh2.axis_labels <- rhs2_axis_labels
| Broadcast (Dynamic_index { over_kind; from_left; other_axes_pointwise; indexed_dims }, sh1, sh2) ->
let subs =
match List.last @@ list_of_dims sh2.output with
| None ->
sh2.output <- Given [ Dim 1 ];
1
| Some Parallel -> !num_parallel_tasks
| Some (Frozen d | Dim d) -> d
in
let check_no_axes indexed_dims =
if List.length indexed_dims <> subs then
raise
(Shape_error
( "Dynamic indexing: provided dimensions and indices assume a different number of axes",
sh1,
sh2 ))
in
let reduced_sh2 =
{
sh2 with
output = map_dims sh2.output ~f:List.drop_last_exn;
axis_labels = shift_axes_of_kind AxisKey.Output sh2 ~f:(( - ) 1);
}
in
let list_dims sh = list_of_dims @@ dims_of_kind over_kind sh in
if
Option.is_some indexed_dims
&& List.is_empty (list_dims sh1)
&& is_unknown (dims_of_kind over_kind cur_sh)
then (* Wait for more information. *) ()
else if Option.is_none indexed_dims && List.length (list_dims sh1) < subs then
raise (Shape_error ("Insufficient indexed axes information for dynamic indexing", sh1, sh2))
else if from_left then (
let n_par_axes =
match (indexed_dims, dims_of_kind over_kind sh1) with
| Some indexed_dims, (Unknown | Inferred []) ->
check_no_axes indexed_dims;
(* Infer and fix the dimensions -- preserve "indexing from the left". *)
let indexed_dims = List.map ~f:dim indexed_dims in
let lhs_dims = list_dims cur_sh in
update_kind ~f:(fun _ -> Given (indexed_dims @ lhs_dims)) over_kind sh1;
0
| None, _ ->
let dims = list_dims sh1 in
(* Fix the dimensions: we don't want new axes inferred to the left of indexed axes. *)
update_kind ~f:(fun _ -> Given dims) over_kind sh1;
count_parallel_among subs dims
| Some indexed_dims, _ ->
check_no_axes indexed_dims;
let dims = list_dims sh1 in
(* Verify the dimensions of the indexed axes. *)
let _indexed_dims_with_par =
take_keeping_parallel ~debug_sh1:sh1 ~debug_sh2:sh2 ~indexed_dims subs dims
in
(* Fix the dimensions: we don't want new axes inferred to the left of indexed axes. *)
update_kind ~f:(fun _ -> Given dims) over_kind sh1;
count_parallel_among subs dims
in
let sh1_size = List.length @@ list_dims sh1 in
let reduced_dims over_dims = map_dims over_dims ~f:(drop_keeping_parallel subs) in
let reduced_sh1 = map_over_kind over_kind ~f:reduced_dims sh1 in
let drop_left from_end = if from_end > sh1_size - subs - n_par_axes then -1 else from_end in
(* TODO: unfortunately we ignore labels on parallel axes. *)
let reduced_sh1 = { reduced_sh1 with axis_labels = shift_axes_of_kind over_kind sh1 ~f:drop_left } in
let extended_sh, logic =
if other_axes_pointwise then (None, Broadcast (Pointwise_bin, reduced_sh1, reduced_sh2))
else