
[Operator] index_add optimized #427

Open
wants to merge 3 commits into master
Conversation

GwokHiujin (Collaborator)

PR Category

Operator

Type of Change

Bug Fix & Performance Optimization

Description

Implementing index_add with indexing resolves the bugs present in the previous version, and I accelerated it using code generation.
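For reference, the core idea can be sketched as a hand-written Triton kernel for the 2-D, dim=0 case. This is an illustrative sketch only, not the generated code: the kernel name, grid layout, and block size are assumptions.

import triton
import triton.language as tl


@triton.jit
def index_add_dim0_kernel(
    out_ptr, src_ptr, index_ptr,
    N,                      # number of columns in src / out
    alpha,
    BLOCK_N: tl.constexpr,
):
    # One program per source row; columns are tiled by BLOCK_N.
    row = tl.program_id(0)
    cols = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = cols < N

    # Gather the destination row selected by `index` for this source row.
    dst_row = tl.load(index_ptr + row)

    src = tl.load(src_ptr + row * N + cols, mask=mask, other=0.0)
    # Atomic add so that duplicate indices accumulate correctly.
    tl.atomic_add(out_ptr + dst_row * N + cols, src * alpha, mask=mask)


def index_add_dim0(out, index, src, alpha=1.0, BLOCK_N=1024):
    # Equivalent to out.index_add_(0, index, src, alpha=alpha).
    num_index, N = src.shape
    grid = (num_index, triton.cdiv(N, BLOCK_N))
    index_add_dim0_kernel[grid](out, src, index, N, alpha, BLOCK_N=BLOCK_N)
    return out

The actual kernel is produced by code generation so it can handle arbitrary dims and shapes; the sketch above only shows the indexing-plus-atomic-add pattern.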

Note that I use tl.atomic_add in the kernel, which currently doesn't support bfloat16, so I temporarily skip that dtype during testing.
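A sketch of how that skip might look in the accuracy test (the test name, shapes, and parametrization here are illustrative, not the actual test code):

import pytest
import torch

@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
def test_accuracy_index_add(dtype):
    if dtype == torch.bfloat16:
        # tl.atomic_add has no bfloat16 support yet, so skip for now.
        pytest.skip("tl.atomic_add does not support bfloat16")
    inp = torch.zeros(64, 64, dtype=dtype, device="cuda")
    idx = torch.randint(0, 64, (32,), device="cuda")
    src = torch.randn(64, 32, dtype=dtype, device="cuda")
    ref = inp.index_add(1, idx, src)
    # ... compare ref against the flag_gems result here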

By the way, this solution still doesn't perform well for some shapes, so any suggestions are appreciated :)

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Before (the previous version had a bug for shape (200, 40999, 3)):

test_select_and_slice_perf.py 
Operator: index_add  Performance Test (dtype=torch.float16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011264            0.272384               0.041               0.750               0.031          [torch.Size([64, 64]), 1, torch.Size([32]), torch.Size([64, 32])]
SUCCESS               0.012288            0.276480               0.044              10.750               0.478          [torch.Size([256, 256]), 1, torch.Size([128]), torch.Size([256, 128])]
SUCCESS               0.027648            0.279552               0.099              76.000               7.516          [torch.Size([1024, 1024]), 1, torch.Size([512]), torch.Size([1024, 512])]
SUCCESS               0.276480            0.552960               0.500             121.422              60.711          [torch.Size([4096, 4096]), 1, torch.Size([2048]), torch.Size([4096, 2048])]
SUCCESS               1.100800            4.300800               0.256             122.166              31.269          [torch.Size([1024, 65536]), 1, torch.Size([32768]), torch.Size([1024, 32768])]
SUCCESS              24.657921           25.936895               0.951              65.318              62.097          [torch.Size([268435456]), 0, torch.Size([134217728]), torch.Size([134217728])]
SUCCESS               0.011264            0.270336               0.042               3.552               0.148          [torch.Size([10000, 1]), 1, torch.Size([1]), torch.Size([10000, 1])]
SUCCESS               0.043008            0.278528               0.154             119.071              18.386          [torch.Size([10000, 256]), 1, torch.Size([128]), torch.Size([10000, 128])]
SUCCESS              26.708992           60.860416               0.439              49.084              21.541          [torch.Size([10000, 65536]), 1, torch.Size([32768]), torch.Size([10000, 32768])]
SUCCESS               0.011264            0.273408               0.041               3.552               0.146          [torch.Size([100, 1, 100]), 1, torch.Size([1]), torch.Size([100, 1, 100])]
SUCCESS               0.044032            0.278528               0.158             116.302              18.386          [torch.Size([100, 256, 100]), 1, torch.Size([128]), torch.Size([100, 128, 100])]
SUCCESS               9.287680           68.525055               0.136             141.153              19.131          [torch.Size([100, 65536, 100]), 1, torch.Size([32768]), torch.Size([100, 32768, 100])]


Operator: index_add  Performance Test (dtype=torch.float32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011264            0.273408               0.041               1.477               0.061          [torch.Size([64, 64]), 1, torch.Size([32]), torch.Size([64, 32])]
SUCCESS               0.011264            0.276480               0.041              23.364               0.952          [torch.Size([256, 256]), 1, torch.Size([128]), torch.Size([256, 128])]
SUCCESS               0.028672            0.278528               0.103             146.429              15.074          [torch.Size([1024, 1024]), 1, torch.Size([512]), torch.Size([1024, 512])]
SUCCESS               0.296960            0.715776               0.415             226.041              93.780          [torch.Size([4096, 4096]), 1, torch.Size([2048]), torch.Size([4096, 2048])]
SUCCESS               1.985536            5.766144               0.344             135.327              46.599          [torch.Size([1024, 65536]), 1, torch.Size([32768]), torch.Size([1024, 32768])]
SUCCESS              26.319872           28.300287               0.930              81.592              75.882          [torch.Size([268435456]), 0, torch.Size([134217728]), torch.Size([134217728])]
SUCCESS               0.011264            0.273408               0.041               7.103               0.293          [torch.Size([10000, 1]), 1, torch.Size([1]), torch.Size([10000, 1])]
SUCCESS               0.040960            0.280576               0.146             250.025              36.500          [torch.Size([10000, 256]), 1, torch.Size([128]), torch.Size([10000, 128])]
SUCCESS              51.834881           65.326080               0.793              50.578              40.133          [torch.Size([10000, 65536]), 1, torch.Size([32768]), torch.Size([10000, 32768])]
SUCCESS               0.011264            0.273408               0.041               7.103               0.293          [torch.Size([100, 1, 100]), 1, torch.Size([1]), torch.Size([100, 1, 100])]
SUCCESS               0.043008            0.277504               0.155             238.119              36.904          [torch.Size([100, 256, 100]), 1, torch.Size([128]), torch.Size([100, 128, 100])]
SUCCESS               8.082432           69.867523               0.116             324.370              37.524          [torch.Size([100, 65536, 100]), 1, torch.Size([32768]), torch.Size([100, 32768, 100])]

After:

test_select_and_slice_perf.py 
Operator: index_add  Performance Test (dtype=torch.float16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011264            0.011264               1.000               0.750               0.750          [torch.Size([64, 64]), 1, torch.Size([32]), torch.Size([64, 32])]
SUCCESS               0.012288            0.022528               0.545              10.750               5.864          [torch.Size([256, 256]), 1, torch.Size([128]), torch.Size([256, 128])]
SUCCESS               0.027648            0.062464               0.443              76.000              33.639          [torch.Size([1024, 1024]), 1, torch.Size([512]), torch.Size([1024, 512])]
SUCCESS               0.276480            0.507904               0.544             121.422              66.097          [torch.Size([4096, 4096]), 1, torch.Size([2048]), torch.Size([4096, 2048])]
SUCCESS               1.101824            4.257792               0.259             122.052              31.584          [torch.Size([1024, 65536]), 1, torch.Size([32768]), torch.Size([1024, 32768])]
SUCCESS              24.667135           23.766016               1.038              65.294              67.770          [torch.Size([268435456]), 0, torch.Size([134217728]), torch.Size([134217728])]
SUCCESS               0.012288            0.011264               1.091               3.256               3.552          [torch.Size([10000, 1]), 1, torch.Size([1]), torch.Size([10000, 1])]
SUCCESS               0.043008            0.101376               0.424             119.071              50.515          [torch.Size([10000, 256]), 1, torch.Size([128]), torch.Size([10000, 128])]
SUCCESS              25.495552           60.827648               0.419              51.420              21.552          [torch.Size([10000, 65536]), 1, torch.Size([32768]), torch.Size([10000, 32768])]
SUCCESS               0.012288            0.015360               0.800               3.256               2.605          [torch.Size([100, 1, 100]), 1, torch.Size([1]), torch.Size([100, 1, 100])]
SUCCESS               0.045056            0.131072               0.344             113.659              39.070          [torch.Size([100, 256, 100]), 1, torch.Size([128]), torch.Size([100, 128, 100])]
SUCCESS               9.285632           68.473854               0.136             141.184              19.146          [torch.Size([100, 65536, 100]), 1, torch.Size([32768]), torch.Size([100, 32768, 100])]


Operator: index_add  Performance Test (dtype=torch.float32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011264            0.010240               1.100               1.477               1.625          [torch.Size([64, 64]), 1, torch.Size([32]), torch.Size([64, 32])]
SUCCESS               0.011264            0.012288               0.917              23.364              21.417          [torch.Size([256, 256]), 1, torch.Size([128]), torch.Size([256, 128])]
SUCCESS               0.028672            0.035840               0.800             146.429             117.143          [torch.Size([1024, 1024]), 1, torch.Size([512]), torch.Size([1024, 512])]
SUCCESS               0.296960            0.684032               0.434             226.041              98.132          [torch.Size([4096, 4096]), 1, torch.Size([2048]), torch.Size([4096, 2048])]
SUCCESS               1.982464            5.712896               0.347             135.537              47.034          [torch.Size([1024, 65536]), 1, torch.Size([32768]), torch.Size([1024, 32768])]
SUCCESS              26.341375           26.116096               1.009              81.525              82.228          [torch.Size([268435456]), 0, torch.Size([134217728]), torch.Size([134217728])]
SUCCESS               0.011264            0.010240               1.100               7.103               7.813          [torch.Size([10000, 1]), 1, torch.Size([1]), torch.Size([10000, 1])]
SUCCESS               0.040960            0.066560               0.615             250.025             153.862          [torch.Size([10000, 256]), 1, torch.Size([128]), torch.Size([10000, 128])]
SUCCESS              51.115009           65.212418               0.784              51.290              40.202          [torch.Size([10000, 65536]), 1, torch.Size([32768]), torch.Size([10000, 32768])]
SUCCESS               0.011264            0.011264               1.000               7.103               7.103          [torch.Size([100, 1, 100]), 1, torch.Size([1]), torch.Size([100, 1, 100])]
SUCCESS               0.043008            0.067584               0.636             238.119             151.530          [torch.Size([100, 256, 100]), 1, torch.Size([128]), torch.Size([100, 128, 100])]
SUCCESS               8.084480           69.749763               0.116             324.288              37.587          [torch.Size([100, 65536, 100]), 1, torch.Size([32768]), torch.Size([100, 32768, 100])]

code.writeline(
    "add_on = tl.load(src + src_idx, mask=mask, other=0) * alpha"
)
code.writeline("tl.atomic_add(out + input_idx, add_on, mask=input_mask)")
Contributor
It's safe to set sem='relaxed'
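Applied to the generated line above, that suggestion would look roughly like this (relaxed memory ordering should suffice here because each atomic only accumulates into out and nothing reads the result before the kernel completes):

code.writeline(
    "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
)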
