-
Notifications
You must be signed in to change notification settings - Fork 42
/
Copy pathlayer_norm.py
executable file
·351 lines (259 loc) · 12.6 KB
/
layer_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
import torch
import torch.nn as nn
from e3nn import o3
from e3nn.o3 import Irreps
from e3nn.util.jit import compile_mode
# Reference:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/DGLPyTorch/DrugDiscovery/SE3Transformer/se3_transformer/model/layers/norm.py
# https://github.com/e3nn/e3nn/blob/main/e3nn/nn/_batchnorm.py
@compile_mode('unsupported')
class EquivariantLayerNorm(torch.nn.Module):
    """Norm-based equivariant layer normalization.

    For every ``(mul, ir)`` block of ``irreps_in`` the input slice is viewed as
    ``[N, mul, ir.dim]``. The per-channel vector norms ``|f|`` (shape
    ``[N, mul]``) are passed through a ``torch.nn.LayerNorm(mul)``. Each channel
    is then rescaled by ``new_norm / |f|``, which keeps its direction (for
    scalars, its sign) and replaces its length with the layer-normed value.

    Args:
        irreps_in: anything accepted by ``e3nn.o3.Irreps`` (an ``Irreps``
            instance, or a string such as ``'4x0e+2x1o'``).
        eps: epsilon forwarded to every ``torch.nn.LayerNorm``.
    """

    NORM_CLAMP = 2 ** -24  # Minimum positive subnormal for FP16

    def __init__(self, irreps_in, eps=1e-5):
        super().__init__()
        # Coerce so string specs work too; an Irreps instance passes through unchanged.
        self.irreps_in = Irreps(irreps_in)
        self.eps = eps
        # One LayerNorm per irreps block, normalizing over that block's multiplicity.
        self.layer_norms = torch.nn.ModuleList(
            torch.nn.LayerNorm(mul, eps) for mul, _ in self.irreps_in
        )

    def forward(self, f_in, **kwargs):
        '''
        Assume `f_in` is of shape [N, C], with C == self.irreps_in.dim.

        Returns a tensor of the same shape.
        Raises AssertionError if the irreps do not cover exactly C channels.
        '''
        N = f_in.shape[0]
        f_out = []
        channel_idx = 0
        for layer_norm, (mul, ir) in zip(self.layer_norms, self.irreps_in):
            width = mul * ir.dim
            feat = f_in[:, channel_idx:channel_idx + width]
            channel_idx += width
            feat = feat.reshape(N, mul, ir.dim)
            # Clamp so the division below never divides by zero (also safe in FP16).
            norm = feat.norm(dim=-1).clamp(min=self.NORM_CLAMP)  # [N, mul]
            new_norm = layer_norm(norm)  # [N, mul]
            feat = feat * new_norm.unsqueeze(-1) / norm.unsqueeze(-1)
            # Explicit width (not -1) so an empty batch (N == 0) still reshapes.
            f_out.append(feat.reshape(N, width))
        if channel_idx != f_in.shape[-1]:
            fmt = "`ix` should have reached node_input.size(-1) ({}), but it ended at {}"
            raise AssertionError(fmt.format(f_in.shape[-1], channel_idx))
        return torch.cat(f_out, dim=-1)

    def __repr__(self):
        return '{}({}, eps={})'.format(self.__class__.__name__,
                                       self.irreps_in, self.eps)
class EquivariantLayerNormV2(nn.Module):
    """Equivariant layer normalization over an e3nn irreps feature vector.

    For each ``(mul, ir)`` block the input slice is viewed as
    ``[N, mul, ir.dim]``. Even scalars (``0e``) are first centered by their
    mean over the multiplicity axis. Every block is then divided by
    ``sqrt(mean_over_mul(squared_norm) + eps)``. With ``affine=True``, a
    learnable per-channel weight scales every block and a learnable
    per-channel bias is added to ``0e`` channels only.

    Args:
        irreps: anything accepted by ``e3nn.o3.Irreps``.
        eps: added to the mean squared norm before the inverse square root.
        affine: whether to learn per-channel weights and ``0e`` biases.
        normalization: ``'norm'`` sums the squared components of each irrep;
            ``'component'`` averages them.
    """

    def __init__(self, irreps, eps=1e-5, affine=True, normalization='component'):
        super().__init__()
        self.irreps = Irreps(irreps)
        self.eps = eps
        self.affine = affine
        num_scalar = sum(mul for mul, ir in self.irreps if ir.l == 0 and ir.p == 1)
        num_features = self.irreps.num_irreps
        if affine:
            # One weight per channel of every irrep; biases only for 0e channels.
            self.affine_weight = nn.Parameter(torch.ones(num_features))
            self.affine_bias = nn.Parameter(torch.zeros(num_scalar))
        else:
            self.register_parameter('affine_weight', None)
            self.register_parameter('affine_bias', None)
        assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"
        self.normalization = normalization

    def __repr__(self):
        return f"{self.__class__.__name__}({self.irreps}, eps={self.eps})"

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, node_input, **kwargs):
        """Normalize ``node_input`` of shape ``[N, self.irreps.dim]``.

        Returns a tensor of the same shape. Raises AssertionError if the irreps
        do not cover exactly the last dimension of ``node_input``.
        """
        # node_input has shape [batch * nodes, dim]; graphs are not separated here.
        num_rows = node_input.shape[0]
        dim = node_input.shape[-1]
        fields = []
        ix = 0  # column offset into node_input
        iw = 0  # offset into affine_weight
        ib = 0  # offset into affine_bias
        for mul, ir in self.irreps:  # mul is the multiplicity (number of copies) of irrep ir
            d = ir.dim
            field = node_input.narrow(1, ix, mul * d)
            ix += mul * d
            # Explicit row count (not -1) so an empty batch still reshapes.
            field = field.reshape(num_rows, mul, d)  # [N, mul, d]

            # For even scalars, first compute and subtract the mean over channels.
            if ir.l == 0 and ir.p == 1:
                field_mean = torch.mean(field, dim=1, keepdim=True)  # [N, 1, 1]
                field = field - field_mean

            # Squared norm of each channel, per the "normalization" option.
            if self.normalization == 'norm':
                field_norm = field.pow(2).sum(-1)  # [N, mul]
            elif self.normalization == 'component':
                field_norm = field.pow(2).mean(-1)  # [N, mul]
            else:
                raise ValueError("Invalid normalization option {}".format(self.normalization))
            field_norm = torch.mean(field_norm, dim=1, keepdim=True)  # [N, 1]
            field_norm = (field_norm + self.eps).pow(-0.5)  # [N, 1]

            if self.affine:
                weight = self.affine_weight[None, iw: iw + mul]  # [1, mul]
                iw += mul
                field_norm = field_norm * weight  # [N, mul]

            # unsqueeze instead of reshape(-1, mul, 1): without affine, field_norm
            # is [N, 1], and reshaping it to mul columns would mix up rows.
            field = field * field_norm.unsqueeze(-1)  # [N, mul, d]

            if self.affine and d == 1 and ir.p == 1:  # even scalars
                bias = self.affine_bias[ib: ib + mul]  # [mul]
                ib += mul
                field += bias.reshape(mul, 1)  # [N, mul, 1]

            fields.append(field.reshape(num_rows, mul * d))  # [N, mul * d]

        if ix != dim:
            fmt = "`ix` should have reached node_input.size(-1) ({}), but it ended at {}"
            msg = fmt.format(dim, ix)
            raise AssertionError(msg)
        output = torch.cat(fields, dim=-1)  # [N, dim]
        return output
class EquivariantLayerNormV3(nn.Module):
    '''
    V2 + Centering for vectors of all degrees.

    For each ``(mul, ir)`` block the input slice is viewed as
    ``[N, mul, ir.dim]``. Unlike V2, every block (not only ``0e``) is centered
    by its mean over the multiplicity axis. Every block is then divided by
    ``sqrt(mean_over_mul(squared_norm) + eps)``. With ``affine=True``, a
    learnable per-channel weight scales every block and a learnable
    per-channel bias is added to ``0e`` channels only.

    Args:
        irreps: anything accepted by ``e3nn.o3.Irreps``.
        eps: added to the mean squared norm before the inverse square root.
        affine: whether to learn per-channel weights and ``0e`` biases.
        normalization: ``'norm'`` sums the squared components of each irrep;
            ``'component'`` averages them.
    '''

    def __init__(self, irreps, eps=1e-5, affine=True, normalization='component'):
        super().__init__()
        self.irreps = Irreps(irreps)
        self.eps = eps
        self.affine = affine
        num_scalar = sum(mul for mul, ir in self.irreps if ir.l == 0 and ir.p == 1)
        num_features = self.irreps.num_irreps
        if affine:
            # One weight per channel of every irrep; biases only for 0e channels.
            self.affine_weight = nn.Parameter(torch.ones(num_features))
            self.affine_bias = nn.Parameter(torch.zeros(num_scalar))
        else:
            self.register_parameter('affine_weight', None)
            self.register_parameter('affine_bias', None)
        assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"
        self.normalization = normalization

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps})"

    def forward(self, node_input, **kwargs):
        """Normalize ``node_input`` of shape ``[N, self.irreps.dim]``.

        Returns a tensor of the same shape. Raises AssertionError if the irreps
        do not cover exactly the last dimension of ``node_input``.
        """
        num_rows = node_input.shape[0]
        dim = node_input.shape[-1]
        fields = []
        ix = 0  # column offset into node_input
        iw = 0  # offset into affine_weight
        ib = 0  # offset into affine_bias
        for mul, ir in self.irreps:
            d = ir.dim
            field = node_input.narrow(1, ix, mul * d)
            ix += mul * d
            # Explicit row count (not -1) so an empty batch still reshapes.
            field = field.reshape(num_rows, mul, d)  # [N, mul, d]
            # Averaging over channels commutes with the irrep action, so this stays equivariant.
            field_mean = torch.mean(field, dim=1, keepdim=True)  # [N, 1, d]
            field = field - field_mean

            if self.normalization == 'norm':
                field_norm = field.pow(2).sum(-1)  # [N, mul]
            elif self.normalization == 'component':
                field_norm = field.pow(2).mean(-1)  # [N, mul]
            else:
                raise ValueError("Invalid normalization option {}".format(self.normalization))
            field_norm = torch.mean(field_norm, dim=1, keepdim=True)  # [N, 1]
            field_norm = (field_norm + self.eps).pow(-0.5)  # [N, 1]

            if self.affine:
                weight = self.affine_weight[None, iw: iw + mul]  # [1, mul]
                iw += mul
                field_norm = field_norm * weight  # [N, mul]

            # unsqueeze instead of reshape(-1, mul, 1): without affine, field_norm
            # is [N, 1], and reshaping it to mul columns would mix up rows.
            field = field * field_norm.unsqueeze(-1)  # [N, mul, d]

            if self.affine and d == 1 and ir.p == 1:  # even scalars
                bias = self.affine_bias[ib: ib + mul]  # [mul]
                ib += mul
                field += bias.reshape(mul, 1)  # [N, mul, 1]

            # Save the result, to be stacked later with the rest
            fields.append(field.reshape(num_rows, mul * d))  # [N, mul * d]

        if ix != dim:
            fmt = "`ix` should have reached node_input.size(-1) ({}), but it ended at {}"
            msg = fmt.format(dim, ix)
            raise AssertionError(msg)
        output = torch.cat(fields, dim=-1)  # [N, dim]
        return output
class EquivariantLayerNormV4(nn.Module):
    '''
    V3 + Learnable mean shift.

    Identical to V3 except that the per-block channel mean is multiplied by a
    learnable per-channel factor ``mean_shift`` before being subtracted.
    ``mean_shift`` is initialised to 1 for ``0e`` channels (full centering) and
    to 0 for every other irrep (no centering).

    Args:
        irreps: anything accepted by ``e3nn.o3.Irreps``.
        eps: added to the mean squared norm before the inverse square root.
        affine: whether to learn per-channel weights and ``0e`` biases.
        normalization: ``'norm'`` sums the squared components of each irrep;
            ``'component'`` averages them.
    '''

    def __init__(self, irreps, eps=1e-5, affine=True, normalization='component'):
        super().__init__()
        self.irreps = Irreps(irreps)
        self.eps = eps
        self.affine = affine
        num_scalar = sum(mul for mul, ir in self.irreps if ir.l == 0 and ir.p == 1)
        num_features = self.irreps.num_irreps
        # Shape [1, num_features, 1] so a narrowed slice broadcasts over rows and components.
        mean_shift = []
        for mul, ir in self.irreps:
            if ir.l == 0 and ir.p == 1:
                mean_shift.append(torch.ones(1, mul, 1))
            else:
                mean_shift.append(torch.zeros(1, mul, 1))
        mean_shift = torch.cat(mean_shift, dim=1)
        self.mean_shift = nn.Parameter(mean_shift)
        if affine:
            # One weight per channel of every irrep; biases only for 0e channels.
            self.affine_weight = nn.Parameter(torch.ones(num_features))
            self.affine_bias = nn.Parameter(torch.zeros(num_scalar))
        else:
            self.register_parameter('affine_weight', None)
            self.register_parameter('affine_bias', None)
        assert normalization in ['norm', 'component'], "normalization needs to be 'norm' or 'component'"
        self.normalization = normalization

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.irreps}, eps={self.eps})"

    def forward(self, node_input, **kwargs):
        """Normalize ``node_input`` of shape ``[N, self.irreps.dim]``.

        Returns a tensor of the same shape. Raises AssertionError if the irreps
        do not cover exactly the last dimension of ``node_input``.
        """
        num_rows = node_input.shape[0]
        dim = node_input.shape[-1]
        fields = []
        ix = 0  # column offset into node_input
        iw = 0  # offset into affine_weight
        ib = 0  # offset into affine_bias
        i_mean_shift = 0  # offset into mean_shift's channel axis
        for mul, ir in self.irreps:
            d = ir.dim
            field = node_input.narrow(1, ix, mul * d)
            ix += mul * d
            # Explicit row count (not -1) so an empty batch still reshapes.
            field = field.reshape(num_rows, mul, d)  # [N, mul, d]
            field_mean = torch.mean(field, dim=1, keepdim=True)  # [N, 1, d]
            mean_shift = self.mean_shift.narrow(1, i_mean_shift, mul)  # [1, mul, 1]
            i_mean_shift += mul
            field = field - field_mean * mean_shift  # broadcasts to [N, mul, d]

            if self.normalization == 'norm':
                field_norm = field.pow(2).sum(-1)  # [N, mul]
            elif self.normalization == 'component':
                field_norm = field.pow(2).mean(-1)  # [N, mul]
            else:
                raise ValueError("Invalid normalization option {}".format(self.normalization))
            field_norm = torch.mean(field_norm, dim=1, keepdim=True)  # [N, 1]
            field_norm = (field_norm + self.eps).pow(-0.5)  # [N, 1]

            if self.affine:
                weight = self.affine_weight[None, iw: iw + mul]  # [1, mul]
                iw += mul
                field_norm = field_norm * weight  # [N, mul]

            # unsqueeze instead of reshape(-1, mul, 1): without affine, field_norm
            # is [N, 1], and reshaping it to mul columns would mix up rows.
            field = field * field_norm.unsqueeze(-1)  # [N, mul, d]

            if self.affine and d == 1 and ir.p == 1:  # even scalars
                bias = self.affine_bias[ib: ib + mul]  # [mul]
                ib += mul
                field += bias.reshape(mul, 1)  # [N, mul, 1]

            # Save the result, to be stacked later with the rest
            fields.append(field.reshape(num_rows, mul * d))  # [N, mul * d]

        if ix != dim:
            fmt = "`ix` should have reached node_input.size(-1) ({}), but it ended at {}"
            msg = fmt.format(dim, ix)
            raise AssertionError(msg)
        output = torch.cat(fields, dim=-1)  # [N, dim]
        return output
if __name__ == '__main__':
    # Smoke test: print each layer and the max equivariance error under an
    # improper rotation (a rotation composed with inversion, so parity is exercised too).
    torch.manual_seed(10)
    irreps_in = o3.Irreps('4x0e+2x1o+1x2e')
    inputs = irreps_in.randn(10, -1)
    rot = -o3.rand_matrix()
    D = irreps_in.D_from_matrix(rot)

    for layer_cls in (EquivariantLayerNorm, EquivariantLayerNormV2,
                      EquivariantLayerNormV3, EquivariantLayerNormV4):
        ln = layer_cls(irreps_in, eps=1e-5)
        print(ln)
        ln.train()
        # Equivariance: normalizing rotated inputs must equal rotating normalized outputs.
        outputs_before = ln(inputs @ D.T)
        outputs_after = ln(inputs) @ D.T
        print(torch.max(torch.abs(outputs_after - outputs_before)))