-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path: test.py
53 lines (45 loc) · 1.65 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from natten.functional import natten1dqk, natten1dav, natten2dqk, natten2dav
from e_natten import natten1d, natten2d
# Fix the global RNG seed so the random q/k/v test tensors are reproducible.
torch.manual_seed(0)
def test(og_qk, og_av, new_fn, input_shape, kernel_size):
q, k, v = torch.randn(input_shape).to('cuda')
q = q.clone().requires_grad_(True)
q.retain_grad()
k = k.clone().requires_grad_(True)
k.retain_grad()
v = v.clone().requires_grad_(True)
v.retain_grad()
# natten
s_1 = og_qk(q, k, kernel_size, 1)
s_1.retain_grad()
p_1 = torch.softmax(s_1, dim=-1)
p_1.retain_grad()
o_1 = og_av(p_1, v, kernel_size, 1)
# natten_triton
q_2 = q.detach().clone()
q_2.requires_grad = True
q_2.retain_grad()
k_2 = k.detach().clone()
k_2.requires_grad = True
k_2.retain_grad()
v_2 = v.detach().clone()
v_2.requires_grad = True
v_2.retain_grad()
o_2 = new_fn(q_2, k_2, v_2, kernel_size)
print('Forward pass:', torch.allclose(o_1, o_2, atol=1e-5))
# Check backward pass.
loss = torch.sum(o_1 ** 2)
loss.backward()
loss_2 = torch.sum(o_2 ** 2)
loss_2.backward()
print('Backward pass (Q):', torch.allclose(q.grad, q_2.grad, atol=1e-5))
print('Backward pass (K):', torch.allclose(k.grad, k_2.grad, atol=1e-5))
print('Backward pass (V):', torch.allclose(v.grad, v_2.grad, atol=1e-5))
if __name__ == '__main__':
    # Sweep odd kernel sizes, validating the 1D and 2D variants at each size.
    for ks in (3, 5, 7, 9, 11):
        print('# Kernel size: ', ks)
        print('## 1D attention')
        test(natten1dqk, natten1dav, natten1d, (3, 2, 3, 16, 2), ks)
        print('## 2D attention')
        test(natten2dqk, natten2dav, natten2d, (3, 2, 6, 16, 16, 2), ks)