-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathhier_torch.py
1877 lines (1590 loc) · 75 KB
/
hier_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from functools import partial
from multiprocessing import reduction
from typing import Callable, List, Optional, Sequence, Tuple
from absl import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import hier
import metrics
def flat_log_softmax(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Returns the log-likelihood of every node under a flat softmax over leaves.

    The likelihood of a node is the total softmax probability of the leaves that
    descend from it (a leaf counts as its own descendant).

    Args:
        tree: Hierarchy providing `ancestor_mask()` and `leaf_mask()`.
        scores: Leaf scores; the last axis follows the leaf order of
            `tree.leaf_mask()`.
        dim: Must refer to the last axis.

    Returns:
        Tensor whose last axis has one entry per node.
    """
    assert dim in (-1, scores.ndim - 1)
    device = scores.device
    log_p_leaf = F.log_softmax(scores, dim=-1)
    # node_covers_leaf[i, k] is true when node i is a (non-strict) ancestor of leaf k.
    full_mask = tree.ancestor_mask(strict=False)
    node_covers_leaf = full_mask[:, tree.leaf_mask()]
    node_covers_leaf = torch.from_numpy(node_covers_leaf).to(device)
    # Broadcast the leaf log-probs against every node and hide non-descendants.
    neg_inf = torch.tensor(-torch.inf, device=device)
    per_node = torch.where(node_covers_leaf, log_p_leaf.unsqueeze(-2), neg_inf)
    return torch.logsumexp(per_node, dim=-1)
class FlatLogSoftmax(nn.Module):
    """Module form of `flat_log_softmax`; the node-to-leaf mask is built once."""

    def __init__(self, tree):
        super().__init__()
        # Row i flags the leaves that descend from node i (non-strict ancestry).
        ancestor_of = tree.ancestor_mask(strict=False)
        leaf_columns = ancestor_of[:, tree.leaf_mask()]
        # Kept as a plain attribute and transformed in `_apply`.
        # TODO: May be important to avoid copying this to the device.
        # However, overriding _apply() will change the dtype.
        self.is_ancestor_leaf = torch.from_numpy(leaf_columns)

    def _apply(self, fn):
        super()._apply(fn)
        # Apply the same transform (e.g. a device move) to the mask.
        self.is_ancestor_leaf = fn(self.is_ancestor_leaf)
        return self

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        """Returns log p(node) for every node given leaf scores on the last axis."""
        leaf_log_p = F.log_softmax(scores, dim=-1)
        fill = torch.tensor(-torch.inf, device=scores.device)
        # Leaf log-probs where the leaf descends from the node, -inf elsewhere.
        per_node = torch.where(self.is_ancestor_leaf, leaf_log_p.unsqueeze(-2), fill)
        return torch.logsumexp(per_node, dim=-1)
class FlatSoftmaxNLL(nn.Module):
    """Negative log-likelihood of node labels under a flat softmax over leaves.

    Like `F.cross_entropy()`, but a label may be any node of the tree: its
    probability is the total softmax mass of the leaves below it.
    """

    def __init__(self, tree, with_leaf_targets: bool = False, reduction: str = 'mean'):
        super().__init__()
        assert reduction in ('mean', 'none', None)
        if with_leaf_targets:
            raise ValueError('use F.cross_entropy() instead!')
        # Row i flags the leaves that descend from node i (non-strict ancestry).
        covers = tree.ancestor_mask(strict=False)
        self.leaf_masks = torch.from_numpy(covers[:, tree.leaf_mask()])
        self.reduction = reduction

    def _apply(self, fn):
        super()._apply(fn)
        self.leaf_masks = fn(self.leaf_masks)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Returns the mean NLL (reduction='mean') or the per-example NLL."""
        leaf_log_p = F.log_softmax(scores, dim=-1)
        # Leaves that lie under each example's label node.
        in_label = self.leaf_masks[labels, :]
        neg_inf = -torch.tensor(torch.inf, device=scores.device)
        masked = torch.where(in_label, leaf_log_p, neg_inf)
        nll = -torch.logsumexp(masked, dim=-1)
        return torch.mean(nll) if self.reduction == 'mean' else nll
def flat_bertinetto_hxe(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        labels: torch.Tensor,
        alpha: float = 0.0,
        dim: int = -1) -> torch.Tensor:
    """Returns the HXE loss from "Making Better Mistakes" for flat leaf scores.

    Conditional likelihoods log p(node | parent) are derived from a flat softmax
    over leaves. Each conditional NLL is weighted by exp(-alpha * parent_depth),
    and the weighted terms are summed over the ancestors of the label.

    Args:
        tree: Class hierarchy.
        scores: Leaf scores; position k on the last axis is placed at node
            `tree.leaf_subset()[k]`.
        labels: Node indices, with one fewer dimension than `scores`.
        alpha: Decay rate of the per-level weight.
        dim: Must refer to the last axis.

    Returns:
        Scalar mean loss over all label entries.
    """
    assert dim in (-1, scores.ndim - 1)
    assert labels.ndim == scores.ndim - 1
    device = scores.device
    parent = torch.from_numpy(tree.parents(root_loop=True)).to(device)
    parent_depth = torch.from_numpy(tree.depths() - 1).to(device)
    weight = torch.exp(-alpha * parent_depth)
    leaf_subset = torch.from_numpy(tree.leaf_subset()).to(device)
    # Embed leaf scores into a per-node tensor with -inf at internal nodes.
    # index_copy requires the buffer to share the device and dtype of `scores`.
    scores_full = torch.full(
        (*scores.shape[:-1], tree.num_nodes()), -torch.inf,
        dtype=scores.dtype, device=device,
    ).index_copy(-1, leaf_subset, scores)
    # `descendant_logsumexp` is defined elsewhere in this module; by its name it
    # reduces over the descendants of each node.
    logsumexp = descendant_logsumexp(tree, scores_full)
    # Conditional log-likelihood of each node given its parent.
    # Note that log_cond_p of root is always zero.
    log_cond_p = logsumexp - logsumexp[..., parent]
    weighted_cond_nll = weight * -log_cond_p
    # Sum the weighted terms over the (non-strict) ancestors of every node.
    weighted_nll = sum_ancestors(tree, weighted_cond_nll, dim=dim, strict=False)
    label_nll = torch.gather(weighted_nll, dim, labels.unsqueeze(-1)).squeeze(-1)
    return torch.mean(label_nll)
class FlatBertinettoHXE(nn.Module):
    """Implements flat_bertinetto_hxe as an object. Avoids re-computation.

    Only the nodes on each label's path are visited, rather than every node.
    """

    def __init__(self, tree: hier.Hierarchy, alpha: float, with_leaf_targets: bool, reduction: str = 'mean'):
        super().__init__()
        assert reduction in ('mean', 'none', None)
        # Row i holds the padded (pad value -1) path for node i, root included.
        # The forward pass assumes the path runs from the root downwards
        # (lse(parent) precedes lse(child)) — TODO confirm against hier.
        paths = tree.paths_padded(method='constant', pad_value=-1, exclude_root=False)
        # leaf_masks[i, k] is true when node i is a (non-strict) ancestor of leaf k.
        is_ancestor = tree.ancestor_mask(strict=False)
        leaf_masks = is_ancestor[:, tree.leaf_mask()]
        if with_leaf_targets:
            # Labels index leaves, so reorder rows to leaf order.
            paths = paths[tree.leaf_subset(), :]
        self.alpha = alpha
        self.reduction = reduction
        self.max_depth = np.max(tree.depths())
        self.paths = torch.from_numpy(paths)
        self.leaf_masks = torch.from_numpy(leaf_masks)

    def _apply(self, fn):
        super()._apply(fn)
        self.paths = fn(self.paths)
        self.leaf_masks = fn(self.leaf_masks)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Returns the HXE loss (mean, or per label entry when reduction != 'mean')."""
        assert labels.ndim == scores.ndim - 1
        device = scores.device
        # TODO: Can use for loop if this uses too much memory (n d with d = log n).
        # Get leaf-descendant mask for every node on each label's path.
        label_path = self.paths[labels]
        path_valid = (label_path >= 0)
        label_ancestor_leaf_masks = torch.logical_and(
            path_valid.unsqueeze(-1),
            self.leaf_masks[torch.where(path_valid, label_path, 0), :])
        inf = torch.tensor(torch.inf, device=device)
        lse_ancestor = torch.logsumexp(
            torch.where(label_ancestor_leaf_masks, scores.unsqueeze(-2), -inf), dim=-1)
        # lse(parent) - lse(child) = -log p(child | parent).
        cond_nll = -torch.diff(lse_ancestor, dim=-1)
        # Zero the steps whose child is padding. Slice the last axis with `...`
        # so that labels with any number of batch dimensions work.
        cond_nll = torch.where(path_valid[..., 1:], cond_nll, torch.tensor(0.0, device=device))
        # The step from a parent at depth k is weighted by exp(-alpha * k).
        weight = torch.exp(-self.alpha * torch.arange(self.max_depth, device=device))
        loss = torch.sum(cond_nll * weight, dim=-1)
        if self.reduction == 'mean':
            return torch.mean(loss)
        return loss
def hier_softmax_nll(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        labels: torch.Tensor) -> torch.Tensor:
    """Returns the mean NLL of node labels under a YOLO-style conditional softmax.

    Args:
        tree: Class hierarchy.
        scores: Conditional-softmax scores on the last axis.
        labels: Node indices, with one fewer dimension than `scores`.
    """
    node_log_p = hier_log_softmax(tree, scores, dim=-1)
    assert labels.ndim == scores.ndim - 1
    picked = torch.gather(node_log_p, -1, labels.unsqueeze(-1)).squeeze(-1)
    return torch.mean(-picked)
def hier_softmax_nll_with_leaf(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        labels: torch.Tensor) -> torch.Tensor:
    """Returns the mean NLL of leaf labels under a YOLO-style conditional softmax.

    Labels index leaves in `tree.leaf_subset()` order, not nodes.
    """
    leaf_nodes = torch.from_numpy(tree.leaf_subset()).to(scores.device)
    # TODO: Could make this faster by only computing likelihood for targets.
    leaf_log_p = hier_log_softmax(tree, scores, dim=-1).index_select(-1, leaf_nodes)
    assert labels.ndim == scores.ndim - 1
    nll = -leaf_log_p.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return nll.mean()
class HierSoftmaxNLL(nn.Module):
    """Module form of `hier_softmax_nll`; tree-derived tensors are built once.

    With `with_leaf_targets=True`, labels index leaves in `tree.leaf_subset()`
    order instead of nodes.
    """

    def __init__(
            self,
            tree: hier.Hierarchy,
            with_leaf_targets: bool = False):
        super().__init__()
        self.label_order = (
            torch.from_numpy(tree.leaf_subset()) if with_leaf_targets else None)
        self.hier_log_softmax = HierLogSoftmax(tree)

    def _apply(self, fn):
        super()._apply(fn)
        self.label_order = _apply_to_maybe(fn, self.label_order)
        self.hier_log_softmax = self.hier_log_softmax._apply(fn)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Returns the mean NLL of `labels` (trailing axis of `scores` dropped)."""
        # TODO: Could make this faster by only computing likelihood for targets.
        log_p = self.hier_log_softmax(scores, dim=-1)
        if self.label_order is not None:
            log_p = log_p.index_select(-1, self.label_order)
        label_log_p = log_p.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
        return torch.mean(-label_log_p)
class HierSoftmaxCrossEntropy(nn.Module):
    """Cross-entropy for a YOLO-style conditional softmax. Avoids re-computation.

    Targets may be integer labels (nodes, or leaves with `with_leaf_targets`) or
    distributions over the label set. Optional label smoothing mixes the node
    target with `hier.uniform_leaf(tree)`.
    """

    def __init__(
            self,
            tree: hier.Hierarchy,
            with_leaf_targets: bool = False,
            label_smoothing: float = 0.0,
            node_weight: Optional[torch.Tensor] = None):
        super().__init__()
        self.label_smoothing = label_smoothing
        if not with_leaf_targets:
            self.label_order = None
            self.num_labels = tree.num_nodes()
        else:
            self.label_order = torch.from_numpy(tree.leaf_subset())
            self.num_labels = len(self.label_order)
        self.hier_cond_log_softmax = HierCondLogSoftmax(tree)
        # Maps a distribution over labels to the mass held by each node's subtree.
        self.sum_label_descendants = SumDescendants(tree, subset=self.label_order)
        self.prior = torch.from_numpy(hier.uniform_leaf(tree))
        self.node_weight = node_weight

    def _apply(self, fn):
        super()._apply(fn)
        self.label_order = _apply_to_maybe(fn, self.label_order)
        self.hier_cond_log_softmax = self.hier_cond_log_softmax._apply(fn)
        self.sum_label_descendants = self.sum_label_descendants._apply(fn)
        self.prior = fn(self.prior)
        self.node_weight = _apply_to_maybe(fn, self.node_weight)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Returns the mean (over examples) of the summed per-node cross-entropy."""
        assert labels.ndim in [scores.ndim, scores.ndim - 1]
        assert dim in (-1, scores.ndim - 1)
        # Integer labels are expanded to one-hot vectors over the label set.
        if labels.ndim < scores.ndim:
            labels = F.one_hot(labels, self.num_labels)
        labels = labels.type(torch.get_default_dtype())
        target = self.sum_label_descendants(labels)
        if self.label_smoothing:
            eps = self.label_smoothing
            target = (1 - eps) * target + eps * self.prior
        neg_log_cond_p = -self.hier_cond_log_softmax(scores, dim=-1)
        per_node = target * neg_log_cond_p
        if self.node_weight is not None:
            per_node = per_node * self.node_weight
        return torch.mean(torch.sum(per_node, dim=-1))
def hier_cond_log_softmax(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Returns log-likelihood of each node given its parent.

    The k-th score on the last axis belongs to the k-th entry of the
    children lists of the internal nodes, concatenated in
    `tree.internal_subset()` order. The output has one entry per node, and
    entries that are nobody's child (the root) are zero.

    Implementation: the flat scores are scattered with index_copy into a
    [num_internal, max_num_children] grid padded with -inf. One log_softmax over
    the last grid axis then normalises each sibling group, and the result is
    gathered back per node. This is faster than torch.split() + log_softmax.
    """
    assert dim == -1 or dim == scores.ndim - 1
    num_nodes = tree.num_nodes()
    num_internal = tree.num_internal_nodes()
    node_to_children = tree.children()
    cond_children = [node_to_children[x] for x in tree.internal_subset()]
    cond_num_children = list(map(len, cond_children))
    max_num_children = max(cond_num_children)
    # TODO: Use _split_and_pad?
    # (row, col) of every child inside the padded sibling grid.
    row_index = np.concatenate([np.full(n, i) for i, n in enumerate(cond_num_children)])
    col_index = np.concatenate([np.arange(n) for n in cond_num_children])
    device = scores.device
    flat_index = torch.from_numpy(row_index * max_num_children + col_index).to(device)
    child_index = torch.from_numpy(np.concatenate(cond_children)).to(device)
    batch_shape = list(scores.shape[:-1])
    flat_shape = [*batch_shape, num_internal * max_num_children]
    # flat[..., flat_index] = scores, with -inf padding. The buffer must share
    # the dtype of `scores` because index_copy rejects mismatched dtypes.
    flat = torch.full(flat_shape, -torch.inf, dtype=scores.dtype, device=device).index_copy(
        -1, flat_index, scores)
    child_scores = flat.reshape([*batch_shape, num_internal, max_num_children])
    child_log_p = F.log_softmax(child_scores, dim=-1).reshape(flat_shape)
    # log_cond_p[..., child_index] = child_log_p[..., flat_index]. The output
    # follows the dtype of log_softmax, which may differ under autocast.
    log_cond_p = torch.zeros(
        [*batch_shape, num_nodes], dtype=child_log_p.dtype, device=device,
    ).index_copy(-1, child_index, child_log_p.index_select(-1, flat_index))
    return log_cond_p
def hier_log_softmax(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Returns log-likelihood for every node under a conditional softmax.

    log p(node) is the sum of the conditional log-likelihoods
    log p(a | parent(a)) over the (non-strict) ancestors a of the node.
    """
    assert dim in (-1, scores.ndim - 1)
    log_cond_p = hier_cond_log_softmax(tree, scores, dim=dim)
    # TODO: Use functional form here?
    accumulate = SumAncestors(tree, exclude_root=False).to(scores.device)
    return accumulate(log_cond_p, dim=-1)
class HierCondLogSoftmax(nn.Module):
    """Implements hier_cond_log_softmax as an object. Avoids re-computation.

    The children of each internal node occupy one row of a padded
    [num_internal, max_num_children] grid, so a single log_softmax over the
    last grid axis yields log p(child | parent) for every child at once.
    """

    def __init__(self, tree: hier.Hierarchy):
        super().__init__()
        num_nodes = tree.num_nodes()
        num_internal = tree.num_internal_nodes()
        node_to_children = tree.children()
        cond_children = [node_to_children[x] for x in tree.internal_subset()]
        cond_num_children = list(map(len, cond_children))
        max_num_children = max(cond_num_children)
        # TODO: Use _split_and_pad?
        row_index = np.concatenate([np.full(n, i) for i, n in enumerate(cond_num_children)])
        col_index = np.concatenate([np.arange(n) for n in cond_num_children])
        # flat_index[k]: slot of the k-th score in the flattened grid.
        self.flat_index = torch.from_numpy(row_index * max_num_children + col_index)
        # child_index[k]: node that the k-th score belongs to.
        self.child_index = torch.from_numpy(np.concatenate(cond_children))
        self.num_nodes = num_nodes
        self.num_internal = num_internal
        self.max_num_children = max_num_children

    def _apply(self, fn):
        super()._apply(fn)
        self.flat_index = fn(self.flat_index)
        self.child_index = fn(self.child_index)
        return self

    def forward(self, scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Returns log p(node | parent) per node; entries for the root stay zero."""
        assert dim in (-1, scores.ndim - 1)
        device = scores.device
        batch_shape = list(scores.shape[:-1])
        flat_shape = [*batch_shape, self.num_internal * self.max_num_children]
        # flat[..., flat_index] = scores, with -inf padding. The buffer must
        # share the dtype of `scores` because index_copy rejects mismatches.
        flat = torch.full(flat_shape, -torch.inf, dtype=scores.dtype, device=device).index_copy(
            -1, self.flat_index, scores)
        split_shape = [*batch_shape, self.num_internal, self.max_num_children]
        child_log_p = F.log_softmax(flat.reshape(split_shape), dim=-1).reshape(flat_shape)
        # log_cond_p[..., child_index] = child_log_p[..., flat_index]. The output
        # follows the dtype of log_softmax, which may differ under autocast.
        log_cond_p = torch.zeros(
            [*batch_shape, self.num_nodes], dtype=child_log_p.dtype, device=device,
        ).index_copy(-1, self.child_index, child_log_p.index_select(-1, self.flat_index))
        return log_cond_p
class HierLogSoftmax(nn.Module):
    """Module form of `hier_log_softmax`; tree-derived tensors are built once."""

    def __init__(self, tree: hier.Hierarchy):
        super().__init__()
        # Per-node log p(node | parent), then summed over (non-strict) ancestors.
        self.cond_log_softmax = HierCondLogSoftmax(tree)
        self.sum_ancestors_fn = SumAncestors(tree, exclude_root=False)

    def _apply(self, fn):
        super()._apply(fn)
        # Apply fn to the child modules' tensors as well.
        self.cond_log_softmax = self.cond_log_softmax._apply(fn)
        self.sum_ancestors_fn = self.sum_ancestors_fn._apply(fn)
        return self

    def forward(self, scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Returns log p(node) for every node on the last axis."""
        return self.sum_ancestors_fn(self.cond_log_softmax(scores, dim=dim), dim=dim)
# class HierSoftmaxNLLWithInactive(nn.Module):
# """Hierarchical softmax with loss for inactive nodes."""
#
# def __init__(
# self,
# tree: hier.Hierarchy,
# with_leaf_targets: bool = False):
# super().__init__()
# if with_leaf_targets:
# self.label_order = torch.from_numpy(tree.leaf_subset())
# else:
# self.label_order = None
# self.hier_log_softmax = HierLogSoftmax(tree)
#
# def _apply(self, fn):
# super()._apply(fn)
# # Do not apply fn to indices because it might convert dtype.
# self.label_order = _apply_to_maybe(fn, self.label_order)
# self.hier_log_softmax = self.hier_log_softmax._apply(fn)
# return self
#
# def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
# # TODO: Could make this faster by only computing likelihood for targets.
# device = scores.device
#
# # Log-sum-exp for each internal node.
# self.logsumexp = ()
#
# log_prob = self.hier_log_softmax(scores, dim=-1)
# if self.label_order is not None:
# log_prob = torch.index_select(log_prob, -1, self.label_order.to(device))
# nll = -torch.gather(log_prob, -1, labels.unsqueeze(-1)).squeeze(-1)
# return torch.mean(nll)
def leaf_put(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Embeds leaf values in a tree. Internal nodes are set to zero.

    Args:
        tree: Class hierarchy.
        values: Leaf values along `dim`, in `tree.leaf_subset()` order.
        dim: Axis that holds the leaves.

    Returns:
        Tensor with the same dtype and device as `values`, whose `dim` axis has
        one entry per node.
    """
    shape = list(values.shape)
    shape[dim] = tree.num_nodes()
    # index_add (used by leaf_add) requires the destination and source to share
    # a dtype, so match `values` instead of using the global default dtype.
    node_values = torch.zeros(shape, dtype=values.dtype, device=values.device)
    return leaf_add(tree, node_values, values, dim=dim)
def leaf_add(
        tree: hier.Hierarchy,
        node_values: torch.Tensor,
        leaf_values: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Adds values for leaf nodes to values for all nodes.

    Entry k of `leaf_values` along `dim` is added to node `tree.leaf_subset()[k]`.
    Returns a new tensor; `node_values` is not modified (out-of-place index_add).
    """
    leaf_nodes = torch.from_numpy(tree.leaf_subset())
    leaf_nodes = leaf_nodes.to(node_values.device)
    return node_values.index_add(dim, leaf_nodes, leaf_values)
def sum_leaf_descendants(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Computes sum over leaf descendants for each node.

    `values` holds one entry per leaf on the last axis (leaf order of
    `tree.leaf_mask()`). The result holds one entry per node. Ancestry uses the
    default strictness of `tree.ancestor_mask()`.
    """
    # Rows of the ancestor mask are nodes; keep leaf columns, then put leaves first.
    ancestor_of_leaf = tree.ancestor_mask()[:, tree.leaf_mask()]
    matrix = torch.from_numpy(ancestor_of_leaf.T).to(
        device=values.device, dtype=torch.get_default_dtype())
    # TODO: Re-order dimensions to make this work with dim != -1.
    assert dim in (-1, values.ndim - 1)
    return torch.tensordot(values, matrix, dims=1)
def sum_descendants(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1,
        strict: bool = False) -> torch.Tensor:
    """Computes sum over all descendants for each node.

    With `strict=False` each node also includes its own value.
    """
    # Transposed ancestor mask: entry [j, i] is set when j descends from i.
    descends_from = tree.ancestor_mask(strict=strict).T
    matrix = torch.from_numpy(descends_from).to(
        device=values.device, dtype=torch.get_default_dtype())
    # TODO: Re-order dimensions to make this work with dim != -1.
    assert dim in (-1, values.ndim - 1)
    return torch.tensordot(values, matrix, dims=1)
def sum_ancestors(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1,
        strict: bool = False) -> torch.Tensor:
    """Computes sum over ancestors of each node.

    With `strict=False` each node also includes its own value.
    """
    # Entry [i, j] is set when node i is an ancestor of node j.
    ancestor_of = tree.ancestor_mask(strict=strict)
    matrix = torch.from_numpy(ancestor_of).to(
        device=values.device, dtype=torch.get_default_dtype())
    # TODO: Re-order dimensions to make this work with dim != -1.
    assert dim in (-1, values.ndim - 1)
    return torch.tensordot(values, matrix, dims=1)
def sum_leaf_ancestors(
        tree: hier.Hierarchy,
        values: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Computes sum over ancestors of each leaf.

    `values` holds one entry per node on the last axis. The result holds one
    entry per leaf (leaf order of `tree.leaf_mask()`). Ancestry uses the default
    strictness of `tree.ancestor_mask()`.
    """
    # Keep only the leaf columns of the ancestor mask.
    leaf_columns = tree.ancestor_mask()[:, tree.leaf_mask()]
    matrix = torch.from_numpy(leaf_columns).to(
        device=values.device, dtype=torch.get_default_dtype())
    # TODO: Re-order dimensions to make this work with dim != -1.
    assert dim in (-1, values.ndim - 1)
    return torch.tensordot(values, matrix, dims=1)
class Sum(nn.Module):
    """Linear map that sums values over ancestors or descendants. Built once.

    Starting from mask[i, j] = "node i is an ancestor of node j":
      * `subset` keeps only the selected columns (output nodes when summing
        over ancestors, input nodes when summing over descendants);
      * `exclude_root` drops row 0 (the root);
      * `transpose` swaps roles, turning "sum over ancestors" into "sum over
        descendants".
    The forward pass contracts the last axis of `values` with the first axis of
    the resulting matrix.
    """

    def __init__(
            self,
            tree: hier.Hierarchy,
            transpose: bool,
            subset: Optional[np.ndarray] = None,
            exclude_root: bool = False,
            strict: bool = False):
        super().__init__()
        mask = tree.ancestor_mask(strict=strict)
        if subset is not None:
            mask = mask[:, subset]
        if exclude_root:
            mask = mask[1:, :]
        if transpose:
            mask = mask.T
        self.matrix = torch.from_numpy(mask).type(torch.get_default_dtype())

    def _apply(self, fn):
        super()._apply(fn)
        self.matrix = fn(self.matrix)
        return self

    def forward(self, values: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Applies the summation matrix to the last axis of `values`."""
        # TODO: Re-order dimensions to make this work with dim != -1.
        assert dim in (-1, values.ndim - 1)
        return torch.tensordot(values, self.matrix, dims=1)
# SumAncestors(tree, ...): output node j receives the sum of the values at its
# ancestors (non-strict unless strict=True is passed).
SumAncestors = partial(Sum, transpose=False)
# SumDescendants(tree, ...): output node i receives the sum of the values at its
# descendants (non-strict unless strict=True is passed).
SumDescendants = partial(Sum, transpose=True)
def SumLeafAncestors(
        tree: hier.Hierarchy,
        **kwargs):
    """Returns a SumAncestors module whose outputs are restricted to the leaves.

    For each leaf, the module sums the per-node input values over that leaf's
    ancestors. Extra keyword arguments are forwarded to `Sum`.
    """
    leaves = tree.leaf_mask()
    return SumAncestors(tree, subset=leaves, **kwargs)
def SumLeafDescendants(
        tree: hier.Hierarchy,
        **kwargs):
    """Returns a SumDescendants module whose inputs are restricted to the leaves.

    The module maps per-leaf values to, for each node, the sum over its leaf
    descendants. Extra keyword arguments are forwarded to `Sum`.
    """
    leaves = tree.leaf_mask()
    return SumDescendants(tree, subset=leaves, **kwargs)
class MultiLabelNLL(nn.Module):
    """Sum of per-node binary cross-entropies (independent sigmoid per node).

    For label i, the binary target of node j is 1 iff j is a (non-strict)
    ancestor of i. The loss is summed over nodes and averaged over examples.
    """

    def __init__(
            self,
            tree: hier.Hierarchy,
            with_leaf_targets: bool = False,
            include_root: bool = False,
            node_weight: Optional[torch.Tensor] = None):
        """Builds the binary target table.

        Args:
            tree: Class hierarchy.
            with_leaf_targets: Labels index leaves (`tree.leaf_subset()` order)
                instead of nodes.
            include_root: Keep the root column. It is dropped by default because
                its target is always positive.
            node_weight: Optional per-node weights along the last axis, with
                the root included. A torch.Tensor or a np.ndarray is accepted.
        """
        super().__init__()
        # The boolean array binary_targets[i, :] indicates whether
        # each node is an ancestor of node i (including i itself).
        binary_targets = tree.ancestor_mask(strict=False).T
        if not include_root:
            binary_targets = binary_targets[:, 1:]
        if with_leaf_targets:
            binary_targets = binary_targets[tree.leaf_subset(), :]
        if node_weight is not None:
            # as_tensor accepts both ndarrays and Tensors (from_numpy rejects
            # the Tensor type named in the annotation).
            node_weight = torch.as_tensor(node_weight)
            if not include_root:
                # Drop the root along the node (last) axis. `...` also handles
                # a 1-D weight vector, which `[:, 1:]` rejected.
                node_weight = node_weight[..., 1:]
        dtype = torch.get_default_dtype()
        self.binary_targets = torch.from_numpy(binary_targets).type(dtype)
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        self.node_weight = node_weight

    def _apply(self, fn):
        super()._apply(fn)
        self.binary_targets = fn(self.binary_targets)
        self.bce_loss = self.bce_loss._apply(fn)
        self.node_weight = _apply_to_maybe(fn, self.node_weight)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Returns the mean over examples of the (weighted) summed node BCE."""
        targets = torch.index_select(self.binary_targets, 0, labels)
        node_loss = self.bce_loss(scores, targets)
        if self.node_weight is not None:
            node_loss = node_loss * self.node_weight
        # Reduce over classes, then take the mean over examples.
        loss = torch.sum(node_loss, dim=dim)
        return torch.mean(loss)
class MultiLabelFocalLoss(nn.Module):
    """Per-node sigmoid focal loss with ancestor-set binary targets."""

    def __init__(
            self,
            tree: hier.Hierarchy,
            alpha: float,
            gamma: float,
            with_leaf_targets: bool = False,
            include_root: bool = False,
            weighting_strategy: str = 'none'):
        """
        It may be useful to have both alpha and weighting_strategy.
        With alpha = 0.5, node_weight = 1, have huge initial loss with most for negative.
        With alpha = 0.9, node_weight = 1, have log(n) nodes with most weight per example.
        With alpha = 0.5, node_weight = 1/size, have equal weight for all nodes on average.
        With alpha = 0.9, node_weight = 1/size, equal weight is given to nodes at all levels.

        Node weights are built with the same root handling as the targets, so
        their lengths agree when `include_root=True`.
        """
        super().__init__()
        # The boolean array binary_targets[i, :] indicates whether
        # each node is an ancestor of node i (including i itself).
        # By default the root node is excluded since it is always positive.
        binary_targets = tree.ancestor_mask(strict=False).T
        if not include_root:
            binary_targets = binary_targets[:, 1:]
        if with_leaf_targets:
            binary_targets = binary_targets[tree.leaf_subset(), :]
        # Keep the weight vector aligned with the target columns: previously
        # the root was always excluded, which mismatched include_root=True.
        node_weight = make_loss_weights(weighting_strategy, tree, exclude_root=not include_root)
        dtype = torch.get_default_dtype()
        self.binary_targets = torch.from_numpy(binary_targets).type(dtype)
        self.alpha = alpha
        self.gamma = gamma
        self.node_weight = _apply_to_maybe(torch.from_numpy, node_weight)

    def _apply(self, fn):
        super()._apply(fn)
        self.binary_targets = fn(self.binary_targets)
        self.node_weight = _apply_to_maybe(fn, self.node_weight)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Returns the mean over examples of the (weighted) summed node focal loss."""
        targets = torch.index_select(self.binary_targets, 0, labels)
        binary_loss = torchvision.ops.sigmoid_focal_loss(
            scores, targets, alpha=self.alpha, gamma=self.gamma, reduction='none')
        if self.node_weight is not None:
            binary_loss = binary_loss * self.node_weight
        # Reduce over classes, then take the mean over examples.
        loss = torch.sum(binary_loss, dim=dim)
        return torch.mean(loss)
def multilabel_log_likelihood(
        scores: torch.Tensor,
        dim: int = -1,
        insert_root: bool = False,
        replace_root: bool = False,
        temperature: Optional[float] = None) -> torch.Tensor:
    """Returns log sigmoid(scores), optionally with a zero (log 1) root entry.

    Args:
        scores: Per-node logits on the last axis.
        dim: Must refer to the last axis.
        insert_root: Prepend a zero entry for the root (the last axis grows by one).
        replace_root: Overwrite entry 0 of the last axis with zero.
        temperature: If truthy, `scores` are divided by it first.

    Returns:
        Log-likelihoods with the same dtype and device as log sigmoid(scores).
    """
    assert not (insert_root and replace_root)
    assert dim in (-1, scores.ndim - 1)
    if temperature:
        scores = scores / temperature
    logp = F.logsigmoid(scores)
    if insert_root or replace_root:
        # Match logp's dtype/device: index_copy requires equal dtypes, and cat
        # would otherwise promote e.g. half to float32.
        zero = logp.new_zeros((*logp.shape[:-1], 1))
        if insert_root:
            logp = torch.cat([zero, logp], dim=-1)
        else:
            root_index = torch.tensor([0], device=logp.device)
            logp = logp.index_copy(-1, root_index, zero)
    return logp
class HierLogSigmoid(nn.Module):
    """Per-node log-likelihood under a hierarchical (conditional) sigmoid model.

    Input: one logit per non-root node, read as log-odds of p(node | parent).
    Output: one entry per node, equal to the sum of the log-sigmoids over the
    node's non-root ancestors (itself included). The root entry is zero.
    """

    def __init__(self, tree: hier.Hierarchy):
        super().__init__()
        self.log_sigmoid = torch.nn.LogSigmoid()
        self.sum_ancestors = SumAncestors(tree, exclude_root=True)

    def _apply(self, fn):
        super()._apply(fn)
        self.log_sigmoid = self.log_sigmoid._apply(fn)
        self.sum_ancestors = self.sum_ancestors._apply(fn)
        return self

    def forward(self, scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """Returns log p(node) for every node."""
        return self.sum_ancestors(self.log_sigmoid(scores), dim=dim)
# class HierSigmoidNLLLoss(nn.Module):
#
# def __init__(self, tree: hier.Hierarchy, with_leaf_targets: bool = False):
# super().__init__()
# self.log_sigmoid = torch.nn.LogSigmoid()
# subset = tree.leaf_mask() if with_leaf_targets else None
# self.sum_ancestors = SumAncestors(tree, exclude_root=True, strict=True)
# self.parents = torch.from_numpy(tree.parents()[1:])
#
# def _apply(self, fn):
# super()._apply(fn)
# self.log_sigmoid = self.log_sigmoid._apply(fn)
# self.sum_ancestors = self.sum_ancestors._apply(fn)
# return self
#
# def forward(self, scores: torch.Tensor, labels: torch.Tensor, dim: int = -1) -> torch.Tensor:
# logp_node_given_parent = self.log_sigmoid(scores)
# logp_not_node_given_parent = self.log_sigmoid(-scores)
# logp_parent = self.sum_ancestors(logp_node_given_parent, dim=dim)
# logp_node = logp_parent + logp_node_given_parent # TODO: Insert root node.
# logp_parent_and_not_node = logp_node[self.parents] + logp_not_node_given_parent
# # Need logsumexp_ancestors!
# logp_not_node = self.sum_ancestors(torch.exp(logp_parent_and_not_node), dim=dim)
def hier_softmax_cross_entropy(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        q: torch.Tensor,
        dim: int = -1) -> torch.Tensor:
    """Cross-entropy between a target q and a conditional softmax.

    The target distribution q should be hierarchical: one entry per node,
    weighting that node's conditional NLL -log p(node | parent).
    Returns the mean over examples.
    """
    neg_log_cond_p = -hier_cond_log_softmax(tree, scores, dim=dim)
    per_example = (q * neg_log_cond_p).sum(dim=dim)
    return per_example.mean()
def bertinetto_hxe(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        labels: torch.Tensor,
        alpha: float = 0.0,
        dim: int = -1) -> torch.Tensor:
    """HXE loss ("Making Better Mistakes") for a conditional softmax.

    The target is an index of a node in the tree. Each conditional NLL
    -log p(node | parent) is weighted by exp(-alpha * parent_depth) and summed
    over the (non-strict) ancestors of the label. Returns the mean over labels.
    """
    device = scores.device
    log_cond_p = hier_cond_log_softmax(tree, scores, dim=dim)
    # Note that log_cond_p of root is always zero, so its weight is irrelevant.
    depth_of_parent = torch.from_numpy(tree.depths() - 1)
    level_weight = torch.exp(-alpha * depth_of_parent).to(device)
    assert dim in (-1, scores.ndim - 1)
    weighted_nll = sum_ancestors(tree, -level_weight * log_cond_p, dim=dim, strict=False)
    assert labels.ndim == scores.ndim - 1
    label_nll = weighted_nll.gather(dim, labels.unsqueeze(-1)).squeeze(-1)
    return label_nll.mean()
class BertinettoHXE(nn.Module):
    """Avoids re-computation in bertinetto_hxe().

    With `with_leaf_targets=True`, labels index leaves in `tree.leaf_subset()`
    order instead of nodes.
    """

    def __init__(self,
                 tree: hier.Hierarchy,
                 alpha: float = 0.,
                 with_leaf_targets: bool = False):
        super().__init__()
        self.hier_cond_log_softmax_fn = HierCondLogSoftmax(tree)
        # Per-node weight exp(-alpha * parent_depth).
        self.weight = torch.exp(-alpha * torch.from_numpy(tree.depths() - 1))
        self.label_order = (
            torch.from_numpy(tree.leaf_subset()) if with_leaf_targets else None)
        self.sum_ancestors_fn = SumAncestors(tree, strict=False)

    def _apply(self, fn):
        super()._apply(fn)
        self.weight = fn(self.weight)
        self.hier_cond_log_softmax_fn = self.hier_cond_log_softmax_fn._apply(fn)
        self.label_order = _apply_to_maybe(fn, self.label_order)
        self.sum_ancestors_fn = self.sum_ancestors_fn._apply(fn)
        return self

    def forward(self,
                scores: torch.Tensor,
                labels: torch.Tensor,
                dim: int = -1) -> torch.Tensor:
        """Returns the mean weighted NLL of `labels`."""
        log_cond_p = self.hier_cond_log_softmax_fn(scores, dim=dim)
        assert dim in (-1, scores.ndim - 1)
        # Note that log_cond_p of root is always zero.
        weighted_nll = self.sum_ancestors_fn(-self.weight * log_cond_p, dim=dim)
        assert labels.ndim == scores.ndim - 1
        if self.label_order is not None:
            weighted_nll = weighted_nll.index_select(dim, self.label_order)
        label_nll = weighted_nll.gather(dim, labels.unsqueeze(-1)).squeeze(-1)
        return label_nll.mean()
def levelwise_softmax_nll(
        tree: hier.Hierarchy,
        scores: torch.Tensor,
        labels: torch.Tensor,
        with_leaf_targets: bool) -> torch.Tensor:
    """Sum over tree levels of a softmax cross-entropy at each level.

    The last axis of `scores` is split into one group per level, with group
    sizes taken from `hier.level_nodes(tree, extend=True)`, and a softmax is
    taken within each group. At each level, the target is the entry of the
    label's padded path (`tree.paths_padded(method='self', exclude_root=True)`)
    for that level.

    Args:
        tree: Hierarchy over the nodes.
        scores: Concatenated per-level scores on the last axis. Any number of
            leading batch dimensions (including none) is supported.
        labels: Integer labels with shape `scores.shape[:-1]`. They index
            `tree.leaf_subset()` if `with_leaf_targets`, else all nodes.
        with_leaf_targets: Whether labels index leaves rather than nodes.

    Returns:
        Scalar mean over examples of the NLL summed over levels.

    Raises:
        AssertionError: If some label lacks a valid target at some level
            (per the note below, this happens for non-leaf targets).
    """
    level_nodes = hier.level_nodes(tree, extend=True)
    level_sizes = tuple(map(len, level_nodes))
    num_levels = len(level_nodes)
    # Construct map from node to index within softmax at each level.
    node_to_level_target = np.full([num_levels, tree.num_nodes()], -1, dtype=int)
    for i in range(num_levels):
        node_to_level_target[i, level_nodes[i]] = np.arange(level_sizes[i])
    # Get label set.
    if with_leaf_targets:
        label_to_node = tree.leaf_subset()
    else:
        label_to_node = np.arange(tree.num_nodes())
    # Find mapping from label to index within softmax (of ancestor) at each level.
    paths = tree.paths_padded(method='self', exclude_root=True)
    label_to_level_target = np.full([num_levels, len(label_to_node)], -1, dtype=int)
    for i in range(num_levels):
        label_to_level_target[i, :] = node_to_level_target[i, paths[label_to_node, i]]
        # Every label should correspond to a valid target.
        # Note that this will fail for non-leaf targets.
        # TODO: Implement no loss for descendants of label? (for non-leaf targets)
        assert np.all(label_to_level_target[i, :] >= 0)
        assert np.all(label_to_level_target[i, :] < level_sizes[i])
    label_to_level_target = torch.from_numpy(label_to_level_target).to(scores.device)
    level_scores = torch.split(scores, level_sizes, dim=-1)
    level_targets = label_to_level_target[:, labels].unbind(0)
    # F.cross_entropy expects (N, C) scores with (N,) targets, so flatten all
    # leading batch dims; otherwise 2+ batch dims would be misread as classes.
    level_nll = torch.stack(
        [F.cross_entropy(x.reshape(-1, x.shape[-1]), y.reshape(-1), reduction='none')
         for x, y in zip(level_scores, level_targets)],
        dim=-1)
    # Sum over levels, then average over examples.
    total_nll = torch.sum(level_nll, dim=-1)
    return torch.mean(total_nll)
class LevelwiseSoftmaxNLL(nn.Module):
    """Module form of levelwise_softmax_nll() with precomputed level targets.

    Args:
        tree: Hierarchy over the nodes.
        with_leaf_targets: Whether labels index `tree.leaf_subset()` rather
            than all nodes.
        reduction: Only 'mean' is supported.
    """

    def __init__(self, tree: hier.Hierarchy, with_leaf_targets: bool = False, reduction: str = 'mean'):
        super().__init__()
        assert reduction == 'mean'
        level_nodes = hier.level_nodes(tree, extend=True)
        level_sizes = tuple(map(len, level_nodes))
        num_levels = len(level_nodes)
        # Construct map from node to index within softmax at each level.
        node_to_level_target = np.full([num_levels, tree.num_nodes()], -1, dtype=int)
        for i in range(num_levels):
            node_to_level_target[i, level_nodes[i]] = np.arange(level_sizes[i])
        # Get label set.
        if with_leaf_targets:
            label_to_node = tree.leaf_subset()
        else:
            label_to_node = np.arange(tree.num_nodes())
        # Find mapping from label to index within softmax (of ancestor) at each level.
        paths = tree.paths_padded(method='self', exclude_root=True)
        label_to_level_target = np.full([num_levels, len(label_to_node)], -1, dtype=int)
        for i in range(num_levels):
            label_to_level_target[i, :] = node_to_level_target[i, paths[label_to_node, i]]
            # Every label should correspond to a valid target.
            # Note that this will fail for non-leaf targets.
            # TODO: Implement no loss for descendants of label? (for non-leaf targets)
            assert np.all(label_to_level_target[i, :] >= 0)
            assert np.all(label_to_level_target[i, :] < level_sizes[i])
        self.level_sizes = level_sizes
        # Shape (num_levels, num_labels); a plain attribute moved in _apply.
        self.label_to_level_target = torch.from_numpy(label_to_level_target)

    def _apply(self, fn):
        super()._apply(fn)
        self.label_to_level_target = fn(self.label_to_level_target)
        return self

    def forward(self, scores: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Returns the scalar mean over examples of the NLL summed over levels.

        `scores` holds the concatenated per-level scores on its last axis.
        `labels` has shape `scores.shape[:-1]`; any number of batch dims works.
        """
        level_scores = torch.split(scores, self.level_sizes, dim=-1)
        level_targets = self.label_to_level_target[:, labels].unbind(0)
        # F.cross_entropy expects (N, C) scores with (N,) targets, so flatten
        # all leading batch dims; otherwise 2+ batch dims would fail.
        level_nll = torch.stack(
            [F.cross_entropy(x.reshape(-1, x.shape[-1]), y.reshape(-1), reduction='none')
             for x, y in zip(level_scores, level_targets)],
            dim=-1)
        # Sum over levels, then average over examples.
        total_nll = torch.sum(level_nll, dim=-1)
        return torch.mean(total_nll)
def levelwise_log_softmax(tree: hier.Hierarchy, scores: torch.Tensor) -> torch.Tensor:
    """Per-node log-probabilities from an independent softmax at each level.

    The last axis of `scores` is split into one group per level, with sizes
    from `hier.level_nodes(tree, extend=True)`, and each group is
    log-softmaxed. The result is scattered back to node indices.

    Args:
        tree: Hierarchy over the nodes.
        scores: Concatenated per-level scores on the last axis.

    Returns:
        Tensor of shape `(*scores.shape[:-1], tree.num_nodes())`:
        - A node that appears at several levels takes its value from the
          shallowest level.
        - Nodes at no level are -inf.
        - Node 0 (presumably the root) is set to 0.
    """
    batch_dims = tuple(scores.shape[:-1])
    num_nodes = tree.num_nodes()
    level_nodes = hier.level_nodes(tree, extend=True)
    level_sizes = tuple(map(len, level_nodes))
    level_scores = torch.split(scores, level_sizes, dim=-1)
    level_logp = [F.log_softmax(x, dim=-1) for x in level_scores]
    logp = torch.full((*batch_dims, num_nodes), -torch.inf, dtype=scores.dtype, device=scores.device)
    # A node may appear at multiple levels; write the deepest level first so
    # that shallower levels overwrite it.
    for nodes, x in reversed(list(zip(level_nodes, level_logp))):
        # hier.level_nodes yields numpy arrays, but index_copy requires an
        # integer torch index on the same device as `logp`.
        index = torch.as_tensor(nodes, dtype=torch.long, device=scores.device)
        logp = logp.index_copy(-1, index, x)
    logp = logp.index_fill(-1, torch.zeros((), dtype=torch.long, device=scores.device), 0.)
    return logp
class LevelwiseLogSoftmax(nn.Module):
    """Module form of levelwise_log_softmax() with precomputed level indices.

    The node indices of each level are turned into tensors once at
    construction and are carried along by `_apply` when the module moves.
    """

    def __init__(self, tree: hier.Hierarchy):
        super().__init__()
        self.num_nodes = tree.num_nodes()
        groups = hier.level_nodes(tree, extend=True)
        self.level_sizes = tuple(len(g) for g in groups)
        self.level_nodes = [torch.from_numpy(g) for g in groups]

    def _apply(self, fn):
        super()._apply(fn)
        # The index tensors live in a plain list that nn.Module does not track.
        self.level_nodes = [fn(t) for t in self.level_nodes]
        return self

    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        """Returns log-probs of shape `(*scores.shape[:-1], num_nodes)`.

        Nodes at no level stay -inf and node 0 is set to 0.
        """
        out_shape = (*scores.shape[:-1], self.num_nodes)
        chunks = torch.split(scores, self.level_sizes, dim=-1)
        logp = torch.full(out_shape, -torch.inf, dtype=scores.dtype, device=scores.device)
        # A node may belong to several levels: fill the deepest level first so
        # that the value from the shallowest level is the one that remains.
        for idx, chunk in reversed(list(zip(self.level_nodes, chunks))):
            logp = logp.index_copy(-1, idx, F.log_softmax(chunk, dim=-1))
        root = torch.zeros((), dtype=int, device=scores.device)
        return logp.index_fill(-1, root, 0.)
# def descendant_max(tree: hier.Hierarchy, x: torch.Tensor) -> torch.Tensor:
# # TODO: Avoid high memory usage.
# # Maybe this could be done with a (parent[i], children[i, j]) array per depth,
# # where children is padded.
# is_ancestor = torch.from_numpy(tree.ancestor_mask()).to(device=x.device)
# neg_inf = torch.full((), -torch.inf, dtype=x.dtype, device=x.device)
# descendant_values = torch.where(is_ancestor, x.unsqueeze(-1), neg_inf)
# max_value, _ = torch.max(descendant_values, dim=-1)
# return max_value
def descendant_max(tree: hier.Hierarchy, x: torch.Tensor) -> torch.Tensor:
"""Finds the max over all descendants of each node."""
level_parents, level_children = hier.level_successors_padded(tree, method='last')
level_parents = list(map(torch.from_numpy, level_parents))
level_children = list(map(torch.from_numpy, level_children))
num_levels = len(level_parents)
out = x
# Work up from the maximum depth, taking max of node and its children.