forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_foreach.py
102 lines (83 loc) · 3.68 KB
/
test_foreach.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
import torch.cuda
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
class TestForeach(TestCase):
    """Tests for ``torch._foreach_add`` applied to a list of tensors and a scalar.

    Methods take ``device`` (and, where decorated with ``@dtypes``, ``dtype``)
    arguments that are supplied by the device-type test harness; the
    ``@dtypes`` methods are run once per dtype from
    ``torch.testing.get_all_dtypes()``.
    """

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_same_size_tensors(self, device, dtype):
        """Adding 1 to N equally-shaped zero tensors yields N tensors of ones."""
        N = 20
        H = 20
        W = 20
        tensors = [torch.zeros(H, W, device=device, dtype=dtype) for _ in range(N)]

        res = torch._foreach_add(tensors, 1)

        # Guard against a vacuous pass: the loop below checks nothing if fewer
        # tensors come back than went in.
        self.assertEqual(len(res), N)
        # bool + int scalar promotes to int64; every other dtype is preserved.
        # The expectation is the same for every element, so build it once.
        result_dtype = torch.int64 if dtype == torch.bool else dtype
        expected = torch.ones(H, W, device=device, dtype=result_dtype)
        for t in res:
            self.assertEqual(t, expected)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_different_size_tensors(self, device, dtype):
        """Adding 1 works when every tensor in the list has a different shape."""
        N = 20
        H = 20
        W = 20
        # Tensor i has shape (H + i, W + i), so no two inputs share a size.
        tensors = [torch.zeros(H + i, W + i, device=device, dtype=dtype) for i in range(N)]

        res = torch._foreach_add(tensors, 1)

        # Without this, a short result list would silently skip the
        # per-tensor shape/value checks below.
        self.assertEqual(len(res), N)
        # bool + int scalar promotes to int64; every other dtype is preserved.
        result_dtype = torch.int64 if dtype == torch.bool else dtype
        for i, t in enumerate(res):
            self.assertEqual(t, torch.ones(H + i, W + i, device=device, dtype=result_dtype))

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_empty_list(self, device, dtype):
        """An empty tensor list is rejected with ``RuntimeError``."""
        # NOTE(review): ``dtype`` is unused here; the @dtypes expansion only
        # repeats the same check once per dtype.
        tensors = []
        with self.assertRaises(RuntimeError):
            torch._foreach_add(tensors, 1)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_overlapping_tensors(self, device, dtype):
        """Adding to an expanded (self-overlapping) view produces a correct result."""
        # expand() creates a view without new memory: all 2*1*3 logical
        # elements alias the single element of the (1, 1) source tensor.
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]
        # bool + int scalar promotes to int64 (True + 1 == 2); every other
        # dtype is preserved.
        result_dtype = torch.int64 if dtype == torch.bool else dtype
        expected = [torch.tensor([[[2, 2, 2]], [[2, 2, 2]]], dtype=result_dtype, device=device)]

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_add_scalar_with_different_tensor_dtypes(self, device):
        """Each tensor in a mixed-dtype list keeps its own dtype in the result."""
        tensors = [torch.tensor([1], dtype=torch.float, device=device),
                   torch.tensor([1], dtype=torch.int, device=device)]
        expected = [torch.tensor([2], dtype=torch.float, device=device),
                    torch.tensor([2], dtype=torch.int, device=device)]

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_add_scalar_with_different_scalar_type(self, device):
        """Tensor/scalar dtype mismatches yield the type-promoted result."""
        # int tensor with float scalar
        # should go 'slow' route
        scalar = 1.1
        tensors = [torch.tensor([1], dtype=torch.int, device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([2.1], device=device)])

        # float tensor with int scalar
        # should go 'fast' route
        scalar = 1
        tensors = [torch.tensor([1.1], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([2.1], device=device)])

        # bool tensor with int scalar
        # should go 'slow' route
        scalar = 1
        tensors = [torch.tensor([False], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([1], device=device)])

        # bool tensor with float scalar
        # should go 'slow' route
        scalar = 1.1
        tensors = [torch.tensor([False], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([1.1], device=device)])
# Expand the generic TestForeach template into device-specific test classes,
# registered in this module's globals(). This runs at import time (module
# scope), so the generated classes exist before any test runner collects tests.
instantiate_device_type_tests(TestForeach, globals())

if __name__ == '__main__':
    # Script entry point: hand off to PyTorch's common test runner.
    run_tests()