forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_vmap.py
1096 lines (901 loc) · 44.7 KB
/
test_vmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
from torch import vmap
import functools
import warnings
class TestVmapAPI(TestCase):
def test_non_tensor_output_raises(self):
with self.assertRaisesRegex(ValueError, "got type <class 'float'> as the return"):
output = vmap(lambda x: 3.14)(torch.ones(3))
def multiple_outputs(x):
return x, 3
with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
vmap(multiple_outputs)(torch.ones(3))
def test_different_map_dim_size_raises(self):
x = torch.randn(2)
y = torch.randn(3)
expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
with self.assertRaisesRegex(ValueError, expected_msg):
vmap(torch.mul)(x, y)
def test_func_with_no_inputs(self):
expected_msg = 'got no inputs'
def foo():
return torch.randn(3)
def bar(x):
return torch.randn(3)
with self.assertRaisesRegex(ValueError, expected_msg):
vmap(foo)()
with self.assertRaisesRegex(ValueError, expected_msg):
vmap(bar)()
def test_constant_function(self):
output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))
def test_single_input(self):
x = torch.randn(2, 3)
def square(x):
return x * x
output = vmap(square)(x)
self.assertEqual(output, x * x)
def test_multiple_inputs(self):
x = torch.randn(2, 3)
y = torch.randn(2, 3)
output = vmap(torch.mul)(x, y)
self.assertEqual(output, x * y)
def test_multiple_outputs(self):
def foo(x):
return x * x, x * x * x
x = torch.randn(3)
outputs = vmap(foo)(x)
self.assertEqual(outputs[0], x * x)
self.assertEqual(outputs[1], x * x * x)
def test_multiple_outputs_error_cases(self):
# This is the same thing as
# def returns_tuple_of_tensors(x):
# return x, x
def returns_tuple_of_tensors(x):
return (x, x)
def returns_list_of_two_tensors(x):
return [x, x]
def returns_list_of_one_tensor(x):
return [x]
x = torch.randn(3)
# should not throw
vmap(returns_tuple_of_tensors)(x)
# jax supports these, but we don't yet
msg = "must only return Tensors, got type <class 'list'>"
with self.assertRaisesRegex(ValueError, msg):
vmap(returns_list_of_two_tensors)(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(returns_list_of_one_tensor)(x)
def test_nested_with_same_map_dim(self):
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
output = vmap(vmap(torch.mul))(x, y)
self.assertEqual(output, x * y)
output = vmap(vmap(vmap(torch.mul)))(x, y)
self.assertEqual(output, x * y)
def test_nested_with_different_map_dim(self):
x = torch.randn(2, 3)
y = torch.randn(5, 3)
output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
self.assertEqual(output.shape, (2, 5, 3))
self.assertEqual(output, x.view(2, 1, 3) * y)
z = torch.randn(7, 3)
output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
self.assertEqual(output.shape, (2, 5, 7, 3))
self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
def test_noop_in_inner_vmap(self):
x = torch.randn(3)
y = torch.randn(5)
output = vmap(lambda x: vmap(lambda y: x)(y))(x)
self.assertEqual(output, x.view(3, 1).expand(3, 5))
def test_unsupported_op_err_msg(self):
# Unsupported view op
tensor = torch.randn(2, 3)
with self.assertRaisesRegex(RuntimeError, "doesn't work on in-place or view ops"):
vmap(torch.as_strided, (0, None, None))(tensor, [2, 3], [0, 0])
# The fallback doesn't support TensorList
with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
vmap(lambda t: torch.stack([t]))(tensor)
# Don't support non-tensor returns. This is a limitation of vmap;
# functions that don't return tensors must be special cased
with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
vmap(torch.Tensor.item)(tensor)
def test_unsupported_inplace_op_err_msg(self):
def foo(x):
return x.cos_()
x = torch.randn(3)
with self.assertRaisesRegex(
RuntimeError, 'Batching rule not implemented'):
vmap(foo)(x)
def test_nonzero_out_dims(self):
# Basic test
tensor = torch.randn(2, 3)
result = vmap(lambda x: x, out_dims=1)(tensor)
self.assertEqual(result, tensor.permute(1, 0))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# Test that the batch dimension gets permuted to dim 2
tensor = torch.randn(2, 3, 5, 7)
result = vmap(lambda x: x, out_dims=2)(tensor)
self.assertEqual(result, tensor.permute(1, 2, 0, 3))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# negative out_dim
tensor = torch.randn(2, 3, 5, 7)
result = vmap(lambda x: x, out_dims=-1)(tensor)
self.assertEqual(result, tensor.permute(1, 2, 3, 0))
self.assertEqual(result.data_ptr(), tensor.data_ptr())
# check that out_dims works on ALL outputs
tensor = torch.randn(2, 3, 5, 7)
other = torch.randn(2, 3, 5, 7)
result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))
# use out_dims with the maximum vmap-able tensor dims (64 dims)
ndims = 64
shape = [2] + [1] * (ndims - 1)
expected_shape = [1, 1, 2] + [1] * (ndims - 3)
tensor = torch.randn(shape)
result = vmap(lambda x: x, out_dims=2)(tensor)
self.assertEqual(result.shape, expected_shape)
# test something that is not the identity function
def foo(x, y):
return x, x * y, x * y * y
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
result = vmap(foo, out_dims=1)(x, y)
self.assertEqual(
result,
(x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))
def test_multiple_out_dims(self):
def foo(x):
return x, x
def bar(x, y):
return x, x, x, x * y
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
result = vmap(foo, out_dims=(0, 1))(x)
self.assertEqual(result, (x, x.permute(1, 0, 2)))
result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
expected = (
x.permute(1, 2, 0),
x,
x.permute(1, 0, 2),
(x * y).permute(1, 2, 0),
)
self.assertEqual(result, expected)
def test_nested_out_dims(self):
y = torch.randn(2, 3, 5, 7)
# Inner vmap has non-zero out_dim
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
self.assertEqual(result.shape, (2, 5, 3, 7))
self.assertEqual(result, y.permute(0, 2, 1, 3))
# all vmaps have non-zero out_dim
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
self.assertEqual(result.shape, (5, 2, 3, 7))
self.assertEqual(result, y.permute(2, 0, 1, 3))
# throwing in some negative out_dims
result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
self.assertEqual(result.shape, (5, 7, 3, 2))
self.assertEqual(result, y.permute(2, 3, 1, 0))
# testing fn that isn't the identity
x = torch.randn(2, 3)
y = torch.randn(5, 3)
result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
self.assertEqual(result.shape, (3, 2, 5))
self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
def test_out_dims_edge_case(self):
def foo(x):
return x
# Test that we accept out_dims=(1,) for a function with one output.
tensor = torch.randn(2, 3)
expected = vmap(foo, out_dims=1)(tensor)
result = vmap(foo, out_dims=(1,))(tensor)
self.assertEqual(result, expected)
def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
msg = '`out_dims` must be an int or a tuple of int'
tensor = torch.randn(2, 3)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims='lol')(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=('lol',))(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=None)(tensor)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=(None,))(tensor)
def test_out_dims_and_num_outputs_mismatch_err_msg(self):
msg = '`out_dims` must have one dim per output'
x = torch.randn(2, 3, 5)
# Too many out_dims
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: x, out_dims=(0, 0))(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
# Too few out_dims
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x), out_dims=(0,))(x)
with self.assertRaisesRegex(ValueError, msg):
vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
def test_out_dim_out_of_bounds_err_msg(self):
# TODO(rzou): This error message isn't that great. It comes straight
# from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
# the error message in the future in C++
msg = 'Dimension out of range'
x = torch.randn(2, 3, 5)
with self.assertRaisesRegex(IndexError, msg):
vmap(lambda x: x, out_dims=3)(x)
with self.assertRaisesRegex(IndexError, msg):
vmap(lambda x: x, out_dims=-4)(x)
def test_non_zero_in_dims(self):
tensor = torch.randn(2, 3, 5)
# Implicit out_dims = 0; vmap will move the batch dim to the front.
output = vmap(lambda x: x, (1,))(tensor)
self.assertEqual(output, tensor.permute(1, 0, 2))
self.assertEqual(output.data_ptr(), tensor.data_ptr())
x = torch.randn(2, 3)
y = torch.randn(3, 2)
output = vmap(torch.mul, (0, 1))(x, y)
self.assertEqual(output, x * y.t())
output = vmap(torch.mul, (1, 0))(x, y)
self.assertEqual(output, x.t() * y)
def test_none_in_dims(self):
x = torch.randn(2, 3)
y = torch.randn(2, 3)
# None in_dim for a Tensor means we don't map over it
output = vmap(torch.mul, (0, None))(x, y)
self.assertEqual(output.shape, (2, 2, 3))
self.assertEqual(output, x.view(2, 1, 3) * y)
# None in_dim for non-tensor arguments
output = vmap(torch.mul, (0, None))(x, 2)
self.assertEqual(output, x * 2)
def test_nested_non_default_in_dims(self):
x = torch.rand(5, 2, 3)
y = torch.rand(3, 5, 2)
result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))
def test_non_default_in_dims_out_dims(self):
x = torch.randn(2, 3, 5)
# Same in_dim as out_dim, vmap over identity
result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
self.assertEqual(result, x)
self.assertEqual(result.data_ptr(), x.data_ptr())
# Different in_dim from out_dim, vmap over identity
result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
self.assertEqual(result.shape, (2, 5, 3))
self.assertEqual(result, x.transpose(1, 2))
self.assertEqual(result.data_ptr(), x.data_ptr())
def foo(x):
return x * 2
# Same in_dim as out_dim, vmap over operation
result = vmap(foo, in_dims=1, out_dims=1)(x)
self.assertEqual(result, x * 2)
# Different in_dim as out_dim, vmap over operation
result = vmap(foo, in_dims=2, out_dims=1)(x)
self.assertEqual(result.shape, (2, 5, 3))
self.assertEqual(result, (x * 2).transpose(1, 2))
# Basic nested test.
result = vmap(vmap(foo, 1, 1), 1, 1)(x)
self.assertEqual(result, x * 2)
def test_in_dims_wrong_type_err_msg(self):
x = torch.randn(3)
y = torch.randn(3)
msg = 'expected `in_dims` to be int or tuple'
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.mul, [0, 0])(x, y)
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.mul, set({0, 0}))(x, y)
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.mul, 'lol')(x, y)
# The following should not throw
vmap(torch.mul, (0, 0))(x, y)
def test_not_enough_in_dims_err_msg(self):
x = torch.randn(3)
y = torch.randn(3)
msg = r'expected one `in_dim` per input \(got \w+ inputs\)'
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.mul, (0,))(x, y)
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.mul, (0, 0, 0))(x, y)
# The following should not throw
vmap(torch.mul, (0, 0))(x, y)
def test_in_dims_must_be_flat_tuple_err_msg(self):
msg = 'in_dims must be a flat tuple containing ints and/or Nones'
x = torch.randn(3)
y = torch.randn(3)
z = torch.randn(3)
def foo(xy):
return xy[0] * xy[1]
def bar(x, yz):
return x * yz[0] * yz[1]
# NB: jax supports all of the following, we don't yet.
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, ((0, 0),))((x, y))
with self.assertRaisesRegex(ValueError, msg):
vmap(bar, (0, (0, 0)))(x, (y, z))
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, ({0: 0, 1: 0},))({0: x, 1: y})
def test_integer_in_dim_but_not_tensor_input_err_msg(self):
def foo(xy):
return xy[0] * xy[1]
def bar(x, yz):
return x * yz[0] * yz[1]
x = torch.randn(2, 3)
y = torch.randn(2, 3)
# jax supports these, we too can in the future.
msg = 'Got in_dim=0 for input 0, but input 0 is not a Tensor'
with self.assertRaisesRegex(ValueError, msg):
vmap(foo)((x, y))
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, (0,))((x, y))
# jax supports these as well, we too can in the future.
msg = 'Got in_dim=0 for input 1, but input 1 is not a Tensor'
with self.assertRaisesRegex(ValueError, msg):
vmap(foo)(x, (x, y))
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, (0, 0))(x, (x, y))
# the following are errors in jax (and will always be errors)
msg = 'Got in_dim=0 for input 1, but input 1 is not a Tensor'
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.sum)(x, 0)
with self.assertRaisesRegex(ValueError, msg):
vmap(torch.sum, (0, 0))(x, 0)
# The following should not throw
vmap(torch.sum, (0, None))(x, 0)
def test_in_dim_not_in_tensor_err_msg(self):
def foo(x):
return x * x
msg = r'Got in_dim=-?\w for input 0, but input 0 is a Tensor of dimensionality \w'
with self.assertRaisesRegex(ValueError, msg):
vmap(foo)(torch.randn([]))
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, in_dims=(0,))(torch.randn([]))
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, in_dims=(-1,))(torch.randn(2, 3))
with self.assertRaisesRegex(ValueError, msg):
vmap(foo, in_dims=(2,))(torch.randn(2, 3))
# the following should not throw
vmap(foo, in_dims=(0,))(torch.randn(2, 3))
vmap(foo, in_dims=(1,))(torch.randn(2, 3))
def _assert_uses_vmap_fallback(self, vmap_args, inputs):
with warnings.catch_warnings(record=True) as wa:
result = vmap(*vmap_args)(*inputs)
self.assertEqual(len(wa), 2)
self.assertRegex(str(wa[-1].message),
r'falling back to slow \(for loop and stack\) implementation')
def test_fallback_sub(self):
# NB: One day we will implement a batching rule for torch.sub.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
x = torch.randn(5, 7, 11)
y = torch.randn(5, 7, 11)
self._assert_uses_vmap_fallback((torch.sub,), (x, y))
# fallback on torch.sub
x = torch.randn(7, 11, 5)
y = torch.randn(5, 7, 11)
result = vmap(torch.sub, (2, 0))(x, y)
self.assertEqual(result, x.permute(2, 0, 1) - y)
# fallback on torch.sub, nested vmap
x = torch.randn(7, 11, 5)
y = torch.randn(5, 7, 11)
result = vmap(vmap(torch.sub), (2, 0))(x, y)
self.assertEqual(result, x.permute(2, 0, 1) - y)
# big batch size (total 10000)
x = torch.randn(100, 10, 10, 5)
y = torch.randn(100, 10, 10)
result = vmap(vmap(vmap(torch.sub)))(x, y)
self.assertEqual(result, x - y.view(100, 10, 10, 1))
def test_fallback_masked_fill(self):
# NB: One day we will implement a batching rule for masked_fill
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
def run_test(batch_size):
B0 = batch_size
x = torch.randn(B0, 7, 11, 13)
dim = 0
index = torch.tensor([0, 4, 2])
values = torch.randn(B0, 3, 13)
self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values))
result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
expected = torch.index_add(
x, dim + 1, index, values.view(B0, 3, 1, 13))
self.assertEqual(result, expected)
run_test(batch_size=5)
run_test(batch_size=1237)
def test_fallback_multiple_returns(self):
# NB: One day we will implement a batching rule for torch.var_mean
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
B0, B1, B2 = 2, 3, 1237
tensor = torch.randn(B0, 10)
self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
# fallback correctness on torch.var_mean
result = vmap(torch.var_mean)(tensor)
expected = torch.var_mean(tensor, dim=1)
self.assertEqual(result, expected)
# nested vmap
tensor = torch.randn(B0, B1, 10)
result = vmap(vmap(torch.var_mean))(tensor)
expected = torch.var_mean(tensor, dim=2)
self.assertEqual(result, expected)
# big batch size, nested vmap
tensor = torch.randn(B0, B1, B2, 10)
result = vmap(vmap(vmap(torch.var_mean)))(tensor)
expected = torch.var_mean(tensor, dim=3)
self.assertEqual(result, expected)
def test_backward_unsupported_interaction(self):
x = torch.randn(3, requires_grad=True)
y = torch.randn(5)
grad = torch.randn_like(x)
err_msg = r'backward\(\) called inside torch.vmap'
def backward_on_vmapped_tensor(x):
x.sum().backward()
with self.assertRaisesRegex(RuntimeError, err_msg):
vmap(backward_on_vmapped_tensor)(x)
def backward_with_vmapped_grad(x, grad):
x.backward(grad)
with self.assertRaisesRegex(RuntimeError, err_msg):
vmap(backward_with_vmapped_grad)(x, grad)
def completely_unrelated_backward(y):
x.sum().backward()
with self.assertRaisesRegex(RuntimeError, err_msg):
vmap(completely_unrelated_backward)(y)
def test_grad_unsupported_interaction(self):
input_tensor = torch.randn(3, requires_grad=True)
err_msg = 'autograd.grad.* called inside torch.vmap'
captured = torch.randn(3, requires_grad=True)
def output_to_grad_is_vmapped(input_tensor):
output = (captured * input_tensor).sum()
return torch.autograd.grad([output], [captured])[0]
with self.assertRaisesRegex(RuntimeError, err_msg):
vmap(output_to_grad_is_vmapped)(input_tensor)
output = (input_tensor ** 2).sum()
def input_to_grad_is_vmapped(input_tensor):
return torch.autograd.grad([output], [input_tensor])[0]
with self.assertRaisesRegex(RuntimeError, err_msg):
vmap(input_to_grad_is_vmapped)(input_tensor)
def test_batched_gradient_basic(self):
N = 3
x = torch.randn(N, requires_grad=True)
y = torch.randn(N)
def vjp_mul(v):
return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]
batched_v = torch.eye(N)
jacobian = vmap(vjp_mul)(batched_v)
self.assertEqual(jacobian, torch.diagflat(y))
def test_functools_partial(self):
x = torch.randn(3)
y = torch.randn(2, 3)
result = vmap(functools.partial(torch.mul, x))(y)
self.assertEqual(result, x * y)
def test_nn_module(self):
tensor = torch.randn(2, 3)
model = torch.nn.Linear(3, 3, bias=False)
result = vmap(model)(tensor)
self.assertEqual(result, model(tensor))
def slice_inputs(inputs, bdims, i):
    """Return the ``i``-th example of every input as a tuple.

    Inputs whose batch dim in ``bdims`` is ``None`` are passed through
    unchanged; the others are indexed with ``select(bdim, i)``.
    """
    return tuple(
        arg if batch_dim is None else arg.select(batch_dim, i)
        for arg, batch_dim in zip(inputs, bdims)
    )
def reference_vmap(op, inputs, in_dims=0, out_dims=0):
    """Reference implementation of vmap: run ``op`` per example and stack.

    ``in_dims`` may be an int (applied to every input) or one entry per
    input, with ``None`` meaning "not mapped". ``out_dims`` may be an int
    or one entry per output. Returns a Tensor when ``op`` returns a single
    value and a tuple of Tensors when ``op`` returns a tuple.
    """
    if isinstance(in_dims, int):
        in_dims = (in_dims,) * len(inputs)

    # Every mapped input must agree on the size of its batch dimension.
    batch_sizes = [
        tensor.size(batch_dim)
        for tensor, batch_dim in zip(inputs, in_dims)
        if batch_dim is not None
    ]
    assert all(size == batch_sizes[0] for size in batch_sizes)
    num_examples = batch_sizes[0]

    results = tuple(
        op(*slice_inputs(inputs, in_dims, example))
        for example in range(num_examples)
    )
    assert len(results) > 0

    if not isinstance(results[0], tuple):
        # Single-return op: stack the per-example tensors directly.
        assert all(isinstance(result, torch.Tensor) for result in results)
        if isinstance(out_dims, int):
            out_dims = (out_dims,) * 1
        return torch.stack(results, dim=out_dims[0])

    # Multi-return op: stack each output position separately.
    assert all(isinstance(result, tuple) for result in results)
    num_returns = len(results[0])
    assert all(len(result) == num_returns for result in results)
    if isinstance(out_dims, int):
        out_dims = (out_dims,) * num_returns

    per_output_shards = zip(*results)
    return tuple(
        torch.stack(shards, out_dim)
        for shards, out_dim in zip(per_output_shards, out_dims)
    )
class TensorFactory:
    """Random-tensor constructors with a uniform ``(size, device, dtype)`` signature."""

    @staticmethod
    def rand(size, device='cpu', dtype=torch.float):
        """Tensor from ``torch.rand``."""
        options = {'device': device, 'dtype': dtype}
        return torch.rand(size, **options)

    @staticmethod
    def randn(size, device='cpu', dtype=torch.float):
        """Tensor from ``torch.randn``."""
        options = {'device': device, 'dtype': dtype}
        return torch.randn(size, **options)

    @staticmethod
    def randp1(size, device='cpu', dtype=torch.float):
        """Tensor from ``torch.rand`` shifted up by one."""
        uniform = torch.rand(size, device=device, dtype=dtype)
        return uniform + 1
class TestVmapOperators(TestCase):
def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
               check_view=False, check_propagates_grad=True):
    """Compare ``vmap(op, in_dims, out_dims)(*inputs)`` with ``reference_vmap``.

    Args:
        op: callable under test.
        inputs: positional arguments passed to the vmapped ``op``.
        in_dims, out_dims: forwarded to both ``vmap`` and ``reference_vmap``.
        check_view: additionally assert that every output shares its
            storage start with ``inputs[0]``.
        check_propagates_grad: additionally assert that the zeroth output
            requires grad when ``inputs[0]`` does.
    """
    result = vmap(op, in_dims, out_dims)(*inputs)
    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
    self.assertEqual(result, reference_result)
    op_has_single_return = not isinstance(result, tuple)

    if check_view:
        result_as_tuple = (result,) if op_has_single_return else result
        for output in result_as_tuple:
            # Subtract the output's own offset to recover the storage start.
            self.assertEqual(output.data_ptr() - output.storage_offset() * output.element_size(),
                             inputs[0].data_ptr(),
                             msg="result was not a view of the first input!")

    if not check_propagates_grad:
        return
    # Assuming input[0] is a floating-point tensor. Check if the vmap
    # operation propagates the requires_grad flag to the zeroth output.
    # Some vmap operators are implemented in a way that assumes that
    # they are composite with respect to autograd. If the operator ever is
    # changed to not be composite with respect to autograd, then the
    # following check should fail.
    inputs_clone = list(inputs)
    inputs_clone[0] = inputs[0].clone().requires_grad_()
    result = vmap(op, in_dims, out_dims)(*inputs_clone)
    result_as_tuple = (result,) if op_has_single_return else result
    # Index the tuple of outputs, not the tensor: for a single-return op
    # `result[0]` would select the first batch element instead.
    self.assertTrue(result_as_tuple[0].requires_grad)
def _vmap_view_test(self, *args, **kwargs):
    """Run ``_vmap_test`` with ``check_view=True``; all other arguments are forwarded."""
    self._vmap_test(*args, check_view=True, **kwargs)
def _assert_doesnt_use_vmap_fallback(self, vmap_args, inputs):
regex = r'falling back to slow \(for loop and stack\) implementation'
with warnings.catch_warnings(record=True) as wa:
result = vmap(*vmap_args)(*inputs)
for captured_warning in wa:
self.assertNotRegex(str(captured_warning.message), regex)
def test_assert_doesnt_use_vmap_fallback(self):
    """``_assert_doesnt_use_vmap_fallback`` must fail for an op that uses the fallback."""
    # One day we'll implement a batching rule for torch.var_mean.
    # When that happens, please change the example to use an
    # operator that doesn't have a batching rule implemented.
    fallback_op = torch.var_mean
    sample = torch.rand(3)
    with self.assertRaises(AssertionError):
        self._assert_doesnt_use_vmap_fallback([fallback_op], [sample])
def _test_unary(self, op, getter, device):
    """Exercise a unary ``op`` under single and doubly nested vmap.

    ``getter(size, device)`` builds each input tensor. The op is first
    required not to hit the slow fallback.
    """
    check = self._vmap_test
    B0, B1 = 7, 11

    self._assert_doesnt_use_vmap_fallback([op], [getter([B0], device)])

    # Single vmap, various in_dims / out_dims
    check(op, [getter([B0, 3], device)])
    check(op, [getter([2, 5, B0, 3], device)], in_dims=2)
    check(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)

    # Doubly nested vmap
    nested = vmap(op)
    check(nested, [getter([B0, B1], device)])
    nested = vmap(op)
    check(nested, [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
    nested = vmap(op, in_dims=2)
    check(nested, [getter([2, 5, B0, B1, 3], device)],
          in_dims=2, out_dims=2)
def test_unary_pointwise_ops(self):
    """Run ``_test_unary`` on CPU for each unary pointwise op listed below."""
    rand = TensorFactory.rand
    randn = TensorFactory.randn
    randp1 = TensorFactory.randp1

    # (operator, input factory) pairs, in the order they are exercised.
    ops_and_getters = (
        (torch.abs, randn),
        (torch.acos, rand),
        (torch.asin, rand),
        (torch.atan, rand),
        (torch.ceil, randn),
        (torch.cos, rand),
        (torch.cosh, rand),
        (torch.digamma, rand),
        (torch.exp, randn),
        (torch.expm1, randn),
        (torch.floor, randn),
        (torch.frac, randn),
        (torch.lgamma, rand),
        (torch.log, randp1),
        (torch.log10, randp1),
        (torch.log1p, randp1),
        (torch.log2, randp1),
        (torch.neg, randn),
        (torch.reciprocal, randp1),
        (torch.relu, randn),
        (torch.round, randn),
        (torch.rsqrt, randp1),
        (torch.sigmoid, randn),
        (torch.sign, randn),
        (torch.sin, rand),
        (torch.sinh, rand),
        (torch.sqrt, rand),
        (torch.tan, rand),
        (torch.tanh, rand),
        (torch.trunc, randn),
    )
    for unary_op, make_input in ops_and_getters:
        self._test_unary(unary_op, make_input, 'cpu')
def test_binary_pointwise_ops(self):
    """Exercise binary pointwise ops under vmap.

    For every ``(op, getter)`` case this covers Tensor-Tensor calls,
    nested vmap, Python-number overloads, type promotion and, when CUDA
    is available, cross-device scalars.
    """
    def get_number(getter):
        return getter([]).item()

    cases = [
        (torch.mul, TensorFactory.rand),
        (lambda x, y: x * y, TensorFactory.randn),
    ]
    test = self._vmap_test
    for op, getter in cases:
        device = 'cpu'
        B0, B1 = 7, 11

        self._assert_doesnt_use_vmap_fallback(
            [op], (getter([B0], device), getter([B0], device)))

        # Single vmap: op(Tensor, Tensor)
        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
        test(op, (getter([B0], device), getter([2, B0, 3], device)),
             in_dims=(0, 1), out_dims=1)
        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
        # NOTE(review): the shapes suggest this case was meant to map the
        # second argument (in_dims=(None, 0)); as written it maps the first
        # argument over a batch of size 2. Left unchanged -- confirm intent.
        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))

        # Nested vmap: op(Tensor, Tensor)
        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
        test(vmap(op, in_dims=(None, 0)),
             (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))

        # Python number overload: op(Tensor, Number) (and vice-versa)
        number = get_number(getter)
        self._test_unary(lambda t: op(t, number), getter, device)
        number = get_number(getter)
        self._test_unary(lambda t: op(number, t), getter, device)

        # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
        test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
        test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
        test(op, (getter([B0], device), getter([B0], device)))

        # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
        test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
        test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))

        # Skip only the CUDA part of this case; a `return` here would also
        # skip every remaining case on CPU-only machines.
        if not torch.cuda.is_available():
            continue

        # Test cross-device scalars
        number = get_number(getter)
        self._test_unary(lambda t: op(t, number), getter, device='cuda')
        self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
def test_chunk(self):
    """``torch.chunk`` under vmap, checked with the view-asserting helper."""
    check = self._vmap_view_test
    chunk = torch.chunk
    B0, B1, B2 = 7, 11, 13

    # tests for torch.chunk(self, chunks: int, dim)
    check(chunk, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
    check(chunk, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))

    nested_once = vmap(chunk, in_dims=(0, None, None))
    check(nested_once, (torch.rand(B1, 1023, B0, 5), 4, 0),
          in_dims=(2, None, None))

    nested_twice = vmap(vmap(lambda t: chunk(t, 4, 1), in_dims=2))
    check(nested_twice,
          (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
def test_diagonal(self):
    """``torch.diagonal`` under vmap, checked with the view-asserting helper."""
    check = self._vmap_view_test
    diagonal = torch.diagonal
    tensor = torch.randn(3, 5, 7, 11, 13)
    unmapped_args = (None, None, None)

    check(diagonal, (tensor, 1, 0, 1), in_dims=(0,) + unmapped_args)
    check(diagonal, (tensor, 0, 2, -1), in_dims=(0,) + unmapped_args)
    check(diagonal, (tensor, 2, 1, 2), in_dims=(1,) + unmapped_args)
    check(diagonal, (tensor, 0, -2, -1), in_dims=(1,) + unmapped_args, out_dims=1)

    # Nested vmaps with the diagonal arguments baked into a lambda.
    check(vmap(lambda t: diagonal(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
    check(vmap(vmap(lambda t: diagonal(t, 0, 0, 1), in_dims=1), in_dims=3),
          (tensor,), in_dims=1, out_dims=1)
def test_expand_as(self):
    """``Tensor.expand_as`` under vmap, checked with the view-asserting helper."""
    check = self._vmap_view_test
    expand_as = torch.Tensor.expand_as
    B0, B1, B2 = 7, 11, 13

    # Single vmap: both mapped, then only one side mapped.
    check(expand_as, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
    check(expand_as, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
    check(expand_as, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))

    # Nested vmaps.
    check(vmap(expand_as), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
    check(vmap(expand_as), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
    check(vmap(expand_as), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
    check(vmap(vmap(expand_as)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
def test_movedim(self):
    """``torch.movedim`` (int and int-list variants) under vmap, checked as views."""
    check = self._vmap_view_test
    movedim = torch.movedim
    B0, B1, B2 = 7, 11, 13

    def doubly_nested():
        # Inner vmap maps dim 2, outer vmap maps dim 0; dims are unmapped.
        return vmap(vmap(movedim, in_dims=(2, None, None)), in_dims=(0, None, None))

    # movedim(tensor, int, int) variant
    check(movedim, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
    check(movedim, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
    check(vmap(movedim, in_dims=(0, None, None)), (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None))
    check(doubly_nested(),
          (torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None))

    # movedim(tensor, intlist, intlist) variant
    check(movedim, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
    check(movedim, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
    check(vmap(movedim, in_dims=(0, None, None)),
          (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None))
    check(doubly_nested(),
          (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None))
def test_narrow(self):
    """``torch.narrow`` under vmap, checked with the view-asserting helper."""
    check = self._vmap_view_test
    narrow = torch.narrow
    B0, B1, B2 = 7, 11, 13

    # Single vmap over dim 0 and dim 1.
    check(narrow, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
    check(narrow, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))

    # Nested vmaps; only the tensor argument is ever mapped.
    once = vmap(narrow, in_dims=(0, None, None, None))
    check(once,
          (torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None))

    twice = vmap(vmap(narrow, in_dims=(2, None, None, None)), in_dims=(0, None, None, None))
    check(twice,
          (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None))
def test_select(self):
    """Exercise ``torch.select`` under vmap via ``self._vmap_view_test``.

    The first two cases pass ``dim`` and ``index`` as unbatched (``None``)
    arguments; the nested-vmap cases bake ``dim=1, index=1`` into a local
    wrapper so that only the tensor is passed through ``vmap``.
    """
    check = self._vmap_view_test
    select = torch.select
    B0, B1, B2 = 7, 11, 13

    def select_dim1_index1(t):
        return select(t, 1, 1)

    check(select, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
    check(select, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))

    lifted_once = vmap(select_dim1_index1)
    check(lifted_once, (torch.rand(B1, 2, B0, 5),), in_dims=2)

    lifted_twice = vmap(vmap(select_dim1_index1, in_dims=1))
    check(lifted_twice, (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)
def test_slice(self):
    """Exercise basic slicing (``t[a:b]`` / ``t[:, a:b]``) under vmap.

    Each slice expression is wrapped in a small local function and run
    through ``self._vmap_view_test``, directly and with one and two extra
    levels of ``vmap``.
    """
    check = self._vmap_view_test
    B0, B1, B2 = 7, 11, 13

    def first_row(t):
        return t[0:1]

    def cols_1_to_3(t):
        return t[:, 1:3]

    def first_col(t):
        return t[:, 0:1]

    check(first_row, (torch.rand(B0, 3, 5),))
    check(cols_1_to_3, (torch.rand(3, 5, B0),), in_dims=2)

    lifted_once = vmap(first_col, in_dims=2)
    check(lifted_once, (torch.rand(3, 5, B0, B1),), in_dims=2)

    lifted_twice = vmap(vmap(first_row, in_dims=2), in_dims=2)
    check(lifted_twice, (torch.rand(3, 5, B0, B1, B2),), in_dims=2)
def test_reshape(self):
    """Exercise ``torch.reshape`` / ``Tensor.reshape`` under vmap.

    Uses ``self._vmap_test`` and passes an explicit ``check_view`` flag for
    every case: ``True`` for the cases whose tensor argument is batched at
    dim 0, ``False`` for the cases batched at a later dim.
    """
    check = self._vmap_test
    reshape = torch.reshape
    B0, B1, B2 = 7, 11, 13

    def flatten(t):
        return t.reshape([-1])

    check(reshape, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
    check(reshape, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False)

    lifted_once = vmap(flatten)
    check(lifted_once, (torch.rand(B0, B1, 2, 5),), check_view=True)

    lifted_twice = vmap(vmap(flatten, in_dims=2), in_dims=1)
    check(lifted_twice, (torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False)
def test_reshape_as(self):
    """Exercise ``torch.Tensor.reshape_as`` under vmap via ``self._vmap_test``.

    Covers both arguments batched, only one of them batched (the other gets
    ``None`` in ``in_dims``), and nested ``vmap``. Each call passes an
    explicit ``check_view`` flag.
    """
    check = self._vmap_test
    reshape_as = torch.Tensor.reshape_as
    B0, B1, B2 = 7, 11, 13

    # Both arguments batched (default in_dims).
    source, like = torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)
    check(reshape_as, (source, like), check_view=True)

    # Only the second argument batched.
    source, like = torch.rand(2 * 5), torch.rand(B0, 2, 5)
    check(reshape_as, (source, like), in_dims=(None, 0), check_view=True)

    # Only the first argument batched, at dim 0 and then at dim 1.
    source, like = torch.rand(B0, 2 * 5), torch.rand(2, 5)
    check(reshape_as, (source, like), in_dims=(0, None), check_view=True)

    source, like = torch.rand(2, B0, 5), torch.rand(1, 1, 10)
    check(reshape_as, (source, like), in_dims=(1, None), check_view=False)

    # One extra vmap level, both arguments batched.
    lifted_once = vmap(reshape_as)
    source, like = torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)
    check(lifted_once, (source, like), check_view=True)

    # Two extra vmap levels that batch only the first argument.
    lifted_twice = vmap(vmap(reshape_as, in_dims=(2, None)), in_dims=(1, None))
    source, like = torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)
    check(lifted_twice, (source, like), in_dims=(5, 0), check_view=False)
def test_split(self):
    """Exercise ``torch.split`` under vmap via ``self._vmap_view_test``.

    Covers the integer ``split_size`` overload and the list-of-sizes
    overload, each called directly, under one extra ``vmap`` level, and under
    two extra ``vmap`` levels (where the split arguments are baked into a
    local wrapper so only the tensor goes through ``vmap``).
    """
    check = self._vmap_view_test
    split = torch.split
    B0, B1, B2 = 7, 11, 13

    # in_dims triples, named after the integer given for the tensor argument.
    tensor_dim0 = (0, None, None)
    tensor_dim1 = (1, None, None)
    tensor_dim2 = (2, None, None)

    def split_by_4_along_1(t):
        return split(t, 4, 1)

    def split_by_list_along_1(t):
        return split(t, [4] * 8 + [8] * 4, 1)

    # tests for torch.split(self, split_size: int, dim)
    check(split, (torch.rand(B0, 2, 1024), 101, -1), in_dims=tensor_dim0)
    check(split, (torch.rand(2, B0, 1024), 130, 1), in_dims=tensor_dim1)

    lifted_once = vmap(split, in_dims=tensor_dim0)
    check(lifted_once, (torch.rand(B1, 1023, B0, 5), 256, 0), in_dims=tensor_dim2)

    lifted_twice = vmap(vmap(split_by_4_along_1, in_dims=2))
    check(lifted_twice, (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    # tests for torch.split(self, split_size: List[int], dim)
    check(split, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=tensor_dim0)
    check(split, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=tensor_dim1)

    lifted_once = vmap(split, in_dims=tensor_dim0)
    check(lifted_once, (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0), in_dims=tensor_dim2)

    lifted_twice = vmap(vmap(split_by_list_along_1, in_dims=2))
    check(lifted_twice, (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
def test_t(self):
    """Exercise ``torch.t`` under vmap via ``self._vmap_view_test``.

    The op is called directly with ``in_dims`` 0 (default) and 1, then
    wrapped in one and two extra levels of ``vmap``.
    """
    check = self._vmap_view_test
    transpose_2d = torch.t
    B0, B1, B2 = 7, 11, 13

    check(transpose_2d, (torch.rand(B0, 2, 5),))
    check(transpose_2d, (torch.rand(2, B0, 5),), in_dims=1)

    lifted_once = vmap(transpose_2d)
    check(lifted_once, (torch.rand(B1, 2, B0, 5),), in_dims=2)

    lifted_twice = vmap(vmap(transpose_2d, in_dims=2))
    check(lifted_twice, (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
def test_T_numpy(self):
    """Exercise the ``Tensor.T`` attribute under vmap via ``self._vmap_view_test``.

    ``.T`` is wrapped in a local function so it can be passed as an op. The
    inputs include a 1-D tensor as well as higher-rank tensors, with zero,
    one and two extra levels of ``vmap``.
    """
    def numpy_style_T(t):
        return t.T

    check = self._vmap_view_test
    B0, B1, B2 = 7, 11, 13

    check(numpy_style_T, (torch.rand(B0, 2, 3, 5),))
    # 1-D input: with the default in_dims its only dimension is the batch one.
    check(numpy_style_T, (torch.rand(B0),))
    check(numpy_style_T, (torch.rand(2, B0, 3, 5),), in_dims=1)

    lifted_once = vmap(numpy_style_T)
    check(lifted_once, (torch.rand(B1, 2, B0, 5),), in_dims=2)

    lifted_once = vmap(numpy_style_T)
    check(lifted_once, (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)

    lifted_twice = vmap(vmap(numpy_style_T, in_dims=2))
    check(lifted_twice, (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
def test_to(self):
    """Exercise ``Tensor.to`` and the casting shortcuts under vmap.

    Runs through ``self._vmap_test``. Covers ``to(device)``, ``to(dtype)``
    and ``to(other_tensor)``, plus ``double()``, ``float()``, ``int()`` and
    ``long()``. The ``int()`` and ``long()`` cases pass
    ``check_propagates_grad=False``.
    """
    check = self._vmap_test
    B0, B1 = 7, 11

    def to_cpu(t):
        return t.to('cpu')

    def to_double(t):
        return t.to(torch.double)

    def to_like(t, o):
        return t.to(o)

    check(to_cpu, (torch.rand(B0),))
    check(to_double, (torch.rand(B0),))

    # to(other): first with both tensors batched, then with only `t` batched.
    source, target = torch.rand(B0), torch.randn(B0, dtype=torch.float64)
    check(to_like, (source, target))

    source, target = torch.rand(B0), torch.randn(B0, dtype=torch.float64)
    check(to_like, (source, target), in_dims=(0, None))

    check(vmap(to_double), (torch.rand(B0, B1, 3),))

    # also test some casting methods
    def as_double(t):
        return t.double()

    def as_float(t):
        return t.float()

    def as_int(t):
        return t.int()

    def as_long(t):
        return t.long()

    check(as_double, (torch.rand(B0),))
    check(as_float, (torch.rand(B0),))
    check(as_int, (torch.rand(B0),), check_propagates_grad=False)
    check(as_long, (torch.rand(B0),), check_propagates_grad=False)
def test_unfold(self):
    """Exercise ``torch.Tensor.unfold`` under vmap via ``self._vmap_view_test``.

    The op is called directly and wrapped in one and two extra ``vmap``
    levels. Only the tensor argument gets an integer ``in_dims`` entry; the
    three remaining arguments always use ``None``. Batch sizes here are
    3, 2 and 5.
    """
    check = self._vmap_view_test
    unfold = torch.Tensor.unfold
    B0, B1, B2 = 3, 2, 5

    # in_dims quadruples, named after the integer given for the tensor argument.
    tensor_dim0 = (0, None, None, None)
    tensor_dim1 = (1, None, None, None)
    tensor_dim2 = (2, None, None, None)

    check(unfold, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=tensor_dim0)
    check(unfold, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=tensor_dim1)

    lifted_once = vmap(unfold, in_dims=tensor_dim0)
    check(lifted_once, (torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=tensor_dim2)

    lifted_twice = vmap(vmap(unfold, in_dims=tensor_dim2), in_dims=tensor_dim0)
    check(lifted_twice, (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=tensor_dim2)
def test_unbind(self):
    """Exercise ``torch.unbind`` under vmap via ``self._vmap_view_test``.

    Covers an explicit ``dim`` argument (passed with ``None`` in
    ``in_dims``), the default ``dim``, an input with a zero-sized trailing
    dimension, and one and two extra levels of ``vmap``.
    """
    check = self._vmap_view_test
    unbind = torch.unbind
    B0, B1, B2 = 7, 11, 13

    def unbind_along_1(t):
        return unbind(t, dim=1)

    check(unbind, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
    # No `dim` given and the last dimension has size 0.
    check(unbind, (torch.rand(B0, 2, 0),))
    check(unbind, (torch.rand(2, B0, 7), 0), in_dims=(1, None))

    lifted_once = vmap(unbind, in_dims=(0, None))
    check(lifted_once, (torch.rand(B1, 1023, B0, 5), 1), in_dims=(2, None))

    lifted_twice = vmap(vmap(unbind_along_1, in_dims=2))
    check(lifted_twice, (torch.rand(B1, 2, B0, 32, B2),), in_dims=2)
def test_view(self):
test = self._vmap_view_test
B0, B1, B2 = 7, 11, 13
op = torch.Tensor.view
# We should error out if the view would produce an incorrect result
with self.assertRaises(RuntimeError):
vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])