riemannian_batch_norm.py
"""Riemannian batch normalization built on a differentiable Frechet mean.

The layer re-centers a batch of manifold-valued points at a learned mean
and re-scales their spread to a learned variance, replacing the Euclidean
shift-and-scale of ordinary batch norm with the manifold's log/exp maps
and parallel transport.
"""
import torch
import torch.nn as nn

from frechetmean.manifolds import Poincare, Lorentz
from frechetmean.frechet import frechet_mean


class RiemannianBatchNorm(nn.Module):
    def __init__(self, dim, manifold):
        super(RiemannianBatchNorm, self).__init__()
        self.man = manifold

        # learned mean (stored as a tangent vector at the origin) and
        # learned variance: the analogues of batch norm's shift and scale
        self.mean = nn.Parameter(self.man.zero_tan(self.man.dim_to_sh(dim)))
        self.var = nn.Parameter(torch.tensor(1.0))

        # running statistics, accumulated during training for eval mode
        self.running_mean = None
        self.running_var = None
        self.updates = 0

    def forward(self, x, training=True, momentum=0.9):
        # learned mean as a point on the manifold
        on_manifold = self.man.exp0(self.mean)

        if training:
            # Frechet mean of the batch (iterative, unbatched: only one
            # mean is needed)
            input_mean = frechet_mean(x, self.man)
            input_var = self.man.frechet_variance(x, input_mean)

            # transport input from the batch mean to the learned mean
            input_logm = self.man.transp(
                input_mean,
                on_manifold,
                self.man.log(input_mean, x),
            )

            # re-scale to the learned variance
            input_logm = (self.var / (input_var + 1e-6)).sqrt() * input_logm

            # project back onto the manifold
            output = self.man.exp(on_manifold.unsqueeze(-2), input_logm)
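
            # The three steps above are the Riemannian analogue of the
            # Euclidean update y = (x - mu) * sqrt(var_learned / var_batch) + beta:
            # log at the batch mean centers the points, parallel transport
            # carries the centered vectors to the learned mean, the sqrt
            # ratio matches the spread to the learned variance, and exp
            # maps the result back onto the manifold.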

            self.updates += 1
            if self.running_mean is None:
                self.running_mean = input_mean
            else:
                # move a fraction (1 - momentum) of the way along the
                # geodesic from the running mean toward the batch mean
                self.running_mean = self.man.exp(
                    self.running_mean,
                    (1 - momentum) * self.man.log(self.running_mean, input_mean)
                )
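
            # The running variance, by contrast, is a flat cumulative
            # average: the batch just seen gets weight 1 / updates and the
            # old estimate keeps the rest.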
            if self.running_var is None:
                self.running_var = input_var
            else:
                self.running_var = (
                    1 - 1 / self.updates
                ) * self.running_var + input_var / self.updates
        else:
            # evaluation: normalize with the statistics accumulated above
            if self.updates == 0:
                raise ValueError("must run training at least once")

            input_mean = frechet_mean(x, self.man)
            input_var = self.man.frechet_variance(x, input_mean)

            # transport input from the batch mean to the running mean
            input_logm = self.man.transp(
                input_mean,
                self.running_mean,
                self.man.log(input_mean, x),
            )

            assert not torch.any(torch.isnan(input_logm))

            # re-scale to the running variance; the n / (n - 1) factor is
            # Bessel's correction on the batch variance
            input_logm = (
                self.running_var / (x.shape[0] / (x.shape[0] - 1) * input_var + 1e-6)
            ).sqrt() * input_logm
            # project back at the running mean, where the transported
            # tangent vectors live
            output = self.man.exp(self.running_mean, input_logm)

        return output
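

# Minimal usage sketch (not part of the layer itself). It assumes the
# frechetmean package's Poincare ball can be constructed with its default
# curvature and that inputs are batches of points inside the unit ball;
# names and shapes here are illustrative, so adjust them to your setup.
if __name__ == "__main__":
    torch.manual_seed(0)

    man = Poincare()
    bn = RiemannianBatchNorm(dim=2, manifold=man)

    # random points mapped safely inside the unit ball
    x = 0.3 * torch.randn(16, 2)
    x = x / (1.0 + x.norm(dim=-1, keepdim=True))

    out = bn(x, training=True)   # normalizes and updates running stats
    print(out.shape)             # same shape as the input batch

    with torch.no_grad():
        out_eval = bn(x, training=False)  # reuses the running statistics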