Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MultilabelSoftMarginLoss for small C #3451

Open
wants to merge 23 commits into
base: develop
Choose a base branch
from

Conversation

littlecutebird
Copy link
Collaborator

  • Add MultilabelSoftMarginLoss forward operation and kernel. Backward is not better compared to ROCm in general.
  • Given input tensor is (N,C), MIOpen is better than ROCm if C <= 24.
  • New API is guarded by MIOPEN_BETA_API macro.
  • Added driver test and gtest for MultilabelSoftMarginLoss.

Unreduced:

type Forward
float32 4.29
float16 5.18
bfloat16 5.56
fp32
input_size num class contiguous ROCm MIOpen Improvement
22 12 uncont 75096 10310 7.28
22 12 cont 44229 9530 4.64
75 19 uncont 81192 14701 5.52
75 19 cont 56166 12658 4.44
33 4 uncont 57349 7502 7.64
33 4 cont 39779 7023 5.66
54 7 uncont 69414 8799 7.89
54 7 cont 43333 8267 5.24
87 23 uncont 72487 16230 4.47
87 23 cont 43380 14988 2.89
10 3 uncont 55270 6382 8.66
10 3 cont 39092 6240 6.26
341 11 uncont 72247 14026 5.15
341 11 cont 52630 10721 4.91
564 17 uncont 77863 21101 3.69
564 17 cont 50517 17193 2.94
289 2 uncont 82376 6897 11.94
289 2 cont 48517 6525 7.44
456 8 uncont 70966 11644 6.09
456 8 cont 51269 9174 5.59
711 15 uncont 71912 19572 3.67
711 15 cont 47701 15414 3.09
987 22 uncont 64823 22452 2.89
987 22 cont 48036 22206 2.16
1324 6 uncont 70775 9208 7.69
1324 6 cont 50566 8071 6.27
9456 13 uncont 68294 17546 3.89
9456 13 cont 53222 13868 3.84
7532 20 uncont 74919 20781 3.61
7532 20 cont 57973 19646 2.95
8451 14 uncont 69190 18061 3.83
8451 14 cont 54470 14187 3.84
2964 21 uncont 67462 22043 3.06
2964 21 cont 50342 21726 2.32
4987 1 uncont 52054 6666 7.81
4987 1 cont 46725 6311 7.40
15432 10 uncont 75144 16479 4.56
15432 10 cont 57302 11218 5.11
29876 18 uncont 126028 44070 2.86
29876 18 cont 103834 42563 2.44
73915 5 uncont 97498 19057 5.12
73915 5 cont 83881 11894 7.05
58241 9 uncont 117339 39412 2.98
58241 9 cont 99467 21957 4.53
19432 16 uncont 87305 25705 3.40
19432 16 cont 72168 24250 2.98
87009 7 uncont 121644 40194 3.03
87009 7 cont 101706 20179 5.04
123456 24 uncont 407737 323724 1.26
123456 24 cont 361812 262560 1.38
543210 12 uncont 880663 679377 1.30
543210 12 cont 749146 314989 2.38
389124 19 uncont 1004881 946215 1.06
389124 19 cont 877893 683687 1.28
678234 11 uncont 1001103 796672 1.26
678234 11 cont 862018 271980 3.17
912345 14 uncont 1610678 1511320 1.07
912345 14 cont 1370831 754619 1.82
456789 8 uncont 495373 297754 1.66
456789 8 cont 432343 102228 4.23
fp16
input_size num class contiguous ROCm MIOpen Improvement
22 12 uncont 75047 10115 7.42
22 12 cont 45956 9672 4.75
75 19 uncont 83224 12586 6.61
75 19 cont 58373 11966 4.88
33 4 uncont 57478 7448 7.72
33 4 cont 38372 7200 5.33
54 7 uncont 70423 8871 7.94
54 7 cont 44116 7698 5.73
87 23 uncont 78584 15501 5.07
87 23 cont 45397 13014 3.49
10 3 uncont 51733 6613 7.82
10 3 cont 39860 6738 5.92
341 11 uncont 76424 10577 7.23
341 11 cont 56022 9547 5.87
564 17 uncont 85480 17439 4.90
564 17 cont 51557 12997 3.97
289 2 uncont 85000 7341 11.58
289 2 cont 52614 6276 8.38
456 8 uncont 72407 9919 7.30
456 8 cont 55542 9441 5.88
711 15 uncont 75000 15679 4.78
711 15 cont 46580 11112 4.19
987 22 uncont 64518 22417 2.88
987 22 cont 46709 16037 2.91
1324 6 uncont 72150 8302 8.69
1324 6 cont 50597 8107 6.24
9456 13 uncont 62006 13475 4.60
9456 13 cont 51189 9867 5.19
7532 20 uncont 65735 20301 3.24
7532 20 cont 54870 14579 3.76
8451 14 uncont 63335 14701 4.31
8451 14 cont 53222 10205 5.22
2964 21 uncont 64792 22026 2.94
2964 21 cont 50020 15005 3.33
4987 1 uncont 53558 6631 8.08
4987 1 cont 48884 6382 7.66
15432 10 uncont 65447 11715 5.59
15432 10 cont 54405 9636 5.65
29876 18 uncont 104891 43003 2.44
29876 18 cont 89705 23041 3.89
73915 5 uncont 84057 12373 6.79
73915 5 cont 72551 10542 6.88
58241 9 uncont 99402 22666 4.39
58241 9 cont 85993 13761 6.25
19432 16 uncont 80808 24799 3.26
19432 16 cont 65030 13316 4.88
87009 7 uncont 101242 20745 4.88
87009 7 cont 87993 13636 6.45
123456 24 uncont 340658 258180 1.32
123456 24 cont 296078 124346 2.38
543210 12 uncont 666273 256651 2.60
543210 12 cont 581993 97748 5.95
389124 19 uncont 788876 603130 1.31
389124 19 cont 690722 204617 3.38
678234 11 uncont 775865 249505 3.11
678234 11 cont 673872 98174 6.86
912345 14 uncont 1223601 584749 2.09
912345 14 cont 1051425 242324 4.34
456789 8 uncont 411589 100762 4.08
456789 8 cont 359936 58260 6.18
bfp16
input_size num class contiguous ROCm MIOpen Improvement
22 12 uncont 79929 10168 7.86
22 12 cont 48998 9334 5.25
75 19 uncont 88984 12710 7.00
75 19 cont 64791 11912 5.44
33 4 uncont 53125 7146 7.43
33 4 cont 39380 7058 5.58
54 7 uncont 72343 8604 8.41
54 7 cont 46788 8143 5.75
87 23 uncont 84841 15181 5.59
87 23 cont 50133 12926 3.88
10 3 uncont 50885 6648 7.65
10 3 cont 39732 6809 5.84
341 11 uncont 85064 10701 7.95
341 11 cont 62502 9583 6.52
564 17 uncont 88585 17439 5.08
564 17 cont 56982 12748 4.47
289 2 uncont 83401 7217 11.56
289 2 cont 53941 6436 8.38
456 8 uncont 80648 9546 8.45
456 8 cont 59126 9263 6.38
711 15 uncont 81032 15644 5.18
711 15 cont 52902 10881 4.86
987 22 uncont 73127 22363 3.27
987 22 cont 52677 15415 3.42
1324 6 uncont 77497 8408 9.22
1324 6 cont 58070 7822 7.42
9456 13 uncont 75031 13564 5.53
9456 13 cont 54533 9885 5.52
7532 20 uncont 76344 19946 3.83
7532 20 cont 58598 14134 4.15
8451 14 uncont 74823 14559 5.14
8451 14 cont 57254 10258 5.58
2964 21 uncont 76375 21528 3.55
2964 21 cont 55365 14525 3.81
4987 1 uncont 49924 6542 7.63
4987 1 cont 51333 6293 8.16
15432 10 uncont 77704 11661 6.66
15432 10 cont 58999 9103 6.48
29876 18 uncont 112364 42612 2.64
29876 18 cont 94618 23415 4.04
73915 5 uncont 92682 12443 7.45
73915 5 cont 73975 10489 7.05
58241 9 uncont 107067 22506 4.76
58241 9 cont 88249 13512 6.53
19432 16 uncont 86952 24870 3.50
19432 16 cont 70278 12996 5.41
87009 7 uncont 112282 20710 5.42
87009 7 cont 92666 13600 6.81
123456 24 uncont 352739 257948 1.37
123456 24 cont 314703 124310 2.53
543210 12 uncont 697716 255300 2.73
543210 12 cont 615900 97712 6.30
389124 19 uncont 823999 598348 1.38
389124 19 cont 729094 204901 3.56
678234 11 uncont 806556 249771 3.23
678234 11 cont 715427 97978 7.30
912345 14 uncont 1285141 584376 2.20
912345 14 cont 1121334 240528 4.66
456789 8 uncont 451545 100513 4.49
456789 8 cont 383602 58616 6.54

Reduced:

type Forward
float32 6.13
float16 6.60
bfloat16 6.59
fp32
input_size num class contiguous ROCm MIOpen Improvement
22 12 uncont 199700 23501 8.50
22 12 cont 195252 23007 8.49
75 19 uncont 202004 28692 7.04
75 19 cont 188531 26670 7.07
33 4 uncont 199732 20834 9.59
33 4 cont 199412 20500 9.73
54 7 uncont 206917 23021 8.99
54 7 cont 192163 21816 8.81
87 23 uncont 196548 31483 6.24
87 23 cont 196196 28892 6.79
10 3 uncont 198068 24514 8.08
10 3 cont 185107 20606 8.98
341 11 uncont 190564 32834 5.80
341 11 cont 192211 29674 6.48
564 17 uncont 193780 38558 5.03
564 17 cont 187475 37550 4.99
289 2 uncont 197956 25439 7.78
289 2 cont 195492 31967 6.12
456 8 uncont 194596 30399 6.40
456 8 cont 190147 29140 6.53
711 15 uncont 192,627 36212 5.32
711 15 cont 187,475 33300 5.63
987 22 uncont 232,408 39660 5.86
987 22 cont 207,173 41230 5.02
1324 6 uncont 195,508 30363 6.44
1324 6 cont 196,484 29033 6.77
9456 13 uncont 194,612 34256 5.68
9456 13 cont 197,412 30633 6.44
7532 20 uncont 378,550 38238 9.90
7532 20 cont 309,344 38367 8.06
8451 14 uncont 263611 34914 7.55
8451 14 cont 372054 33265 11.18
2964 21 uncont 354676 39874 8.89
2964 21 cont 242952 38687 6.28
4987 1 uncont 253658 28194 9.00
4987 1 cont 526837 29780 17.69
15432 10 uncont 300670 33385 9.01
15432 10 cont 234151 30473 7.68
29876 18 uncont 299534 60442 4.96
29876 18 cont 280604 61142 4.59
73915 5 uncont 302190 39127 7.72
73915 5 cont 318480 36127 8.82
58241 9 uncont 254186 56922 4.47
58241 9 cont 238200 39896 5.97
19432 16 uncont 245033 43732 5.60
19432 16 cont 246601 46563 5.30
87009 7 uncont 273483 62256 4.39
87009 7 cont 209333 39344 5.32
123456 24 uncont 431915 346746 1.25
123456 24 cont 386182 286064 1.35
543210 12 uncont 908441 704994 1.29
543210 12 cont 764459 345214 2.21
389124 19 uncont 1029587 967742 1.06
389124 19 cont 904487 709786 1.27
678234 11 uncont 1021409 823427 1.24
678234 11 cont 881075 297173 2.96
912345 14 uncont 1635399 1538590 1.06
912345 14 cont 1392769 784630 1.78
456789 8 uncont 517983 320294 1.62
456789 8 cont 451833 127385 3.55
fp16
input_size num class contiguous ROCm MIOpen Improvement
22 12 uncont 197796 25759 7.68
22 12 cont 191187 24874 7.69
75 19 uncont 204469 26594 7.69
75 19 cont 188963 26901 7.02
33 4 uncont 197652 20763 9.52
33 4 cont 198676 20909 9.50
54 7 uncont 209957 21634 9.70
54 7 cont 191604 21744 8.81
87 23 uncont 193683 32532 5.95
87 23 cont 196932 29905 6.59
10 3 uncont 198916 21830 9.11
10 3 cont 182787 21051 8.68
341 11 uncont 189251 30185 6.27
341 11 cont 193412 28518 6.78
564 17 uncont 199828 35518 5.63
564 17 cont 190115 30688 6.20
289 2 uncont 192708 27608 6.98
289 2 cont 197108 27362 7.20
456 8 uncont 197732 29048 6.81
456 8 cont 186883 31860 5.87
711 15 uncont 189,939 34185 5.56
711 15 cont 197,636 28233 7.00
987 22 uncont 240,776 40727 5.91
987 22 cont 205,205 37692 5.44
1324 6 uncont 191,492 26843 7.13
1324 6 cont 392,680 26633 14.74
9456 13 uncont 198,180 30701 6.46
9456 13 cont 191,955 27949 6.87
7532 20 uncont 350,708 40105 8.74
7532 20 cont 299,246 37656 7.95
8451 14 uncont 262266 34416 7.62
8451 14 cont 324913 33353 9.74
2964 21 uncont 356804 39590 9.01
2964 21 cont 253690 32980 7.69
4987 1 uncont 386103 27963 13.81
4987 1 cont 387271 33122 11.69
15432 10 uncont 252185 38167 6.61
15432 10 cont 240552 32464 7.41
29876 18 uncont 228727 63696 3.59
29876 18 cont 233848 42172 5.55
73915 5 uncont 238392 34950 6.82
73915 5 cont 235928 40340 5.85
58241 9 uncont 237288 39678 5.98
58241 9 cont 263979 29708 8.89
19432 16 uncont 244200 42949 5.69
19432 16 cont 250665 34864 7.19
87009 7 uncont 211845 44461 4.76
87009 7 cont 203412 41158 4.94
123456 24 uncont 362180 279121 1.30
123456 24 cont 321696 145058 2.22
543210 12 uncont 688547 283068 2.43
543210 12 cont 605595 123901 4.89
389124 19 uncont 812574 624126 1.30
389124 19 cont 709108 229774 3.09
678234 11 uncont 794202 275619 2.88
678234 11 cont 695393 124114 5.60
912345 14 uncont 1247250 611878 2.04
912345 14 cont 1073858 270432 3.97
456789 8 uncont 431623 127659 3.38
456789 8 cont 385091 83293 4.62
bfp16
input_size num class contiguous ROCm MIOpen Improvement
22 12 uncont 200820 23625 8.50
22 12 cont 191539 22402 8.55
75 19 uncont 202436 26719 7.58
75 19 cont 198532 26047 7.62
33 4 uncont 201284 20177 9.98
33 4 cont 198980 27612 7.21
54 7 uncont 197556 23128 8.54
54 7 cont 191443 20784 9.21
87 23 uncont 196292 28976 6.77
87 23 cont 193267 29532 6.54
10 3 uncont 195764 21048 9.30
10 3 cont 191508 21300 8.99
341 11 uncont 191219 29705 6.44
341 11 cont 195908 31132 6.29
564 17 uncont 199348 39074 5.10
564 17 cont 191779 35737 5.37
289 2 uncont 196852 26772 7.35
289 2 cont 193156 27593 7.00
456 8 uncont 198772 28727 6.92
456 8 cont 190307 29282 6.50
711 15 uncont 191155 32923 5.81
711 15 cont 195284 36963 5.28
987 22 uncont 245673 40034 6.14
987 22 cont 201444 36145 5.57
1324 6 uncont 196820 31785 6.19
1324 6 cont 195396 38528 5.07
9456 13 uncont 214614 32656 6.57
9456 13 cont 303679 35203 8.63
7532 20 uncont 268907 36941 7.28
7532 20 cont 326129 34491 9.46
8451 14 uncont 261034 32781 7.96
8451 14 cont 245673 39772 6.18
2964 21 uncont 298718 39661 7.53
2964 21 cont 359524 32784 10.97
4987 1 uncont 499538 27643 18.07
4987 1 cont 238200 31629 7.53
15432 10 uncont 268235 32639 8.22
15432 10 cont 241128 30668 7.86
29876 18 uncont 243913 62007 3.93
29876 18 cont 235912 42812 5.51
73915 5 uncont 280700 33047 8.49
73915 5 cont 282476 39611 7.13
58241 9 uncont 257388 41794 6.16
58241 9 cont 275531 30597 9.01
19432 16 uncont 274699 42096 6.53
19432 16 cont 259258 38900 6.66
87009 7 uncont 270427 42576 6.35
87009 7 cont 227783 40625 5.61
123456 24 uncont 391655 277983 1.41
123456 24 cont 346850 143796 2.41
543210 12 uncont 718838 283068 2.54
543210 12 cont 636510 126318 5.04
389124 19 uncont 846529 621904 1.36
389124 19 cont 756168 228387 3.31
678234 11 uncont 824669 277060 2.98
678234 11 cont 735765 125892 5.84
912345 14 uncont 1309143 614295 2.13
912345 14 cont 1144712 268707 4.26
456789 8 uncont 453449 125579 3.61
456789 8 cont 403940 81781 4.94

@littlecutebird
Copy link
Collaborator Author

@CAHEK7 I cannot add you as a reviewer for this PR 😅 So I have tagged to send a notification to you
image

Copy link
Contributor

@CAHEK7 CAHEK7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked only the kernel, the tests and solver code (partially). It looks good.

@CAHEK7
Copy link
Contributor

CAHEK7 commented Jan 7, 2025

@CAHEK7 I cannot add you as a reviewer for this PR 😅 So I have tagged to send a notification to you

That's because I'm not a part of the team now. But I'm checking some code from time to time.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants