Re-implementing gather op #433

Open · wants to merge 1 commit into master
Conversation

0x45f (Collaborator) commented Jan 22, 2025

PR Category

Operator

Type of Change

Performance Optimization

Description

Re-implementing gather op
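For context, the op being re-implemented follows standard `torch.gather` semantics: each output element is looked up from the input along one dimension via an integer index tensor. A minimal sketch using the smallest benchmark case below ([64, 64] index into a [64, 128] input, dim=-1); the shapes are taken from the Size Detail column, the random data is illustrative:

```python
import torch

# Shapes mirroring the smallest benchmark case:
# gather along dim=-1 with a [64, 64] int64 index into a [64, 128] input.
inp = torch.randn(64, 128)
index = torch.randint(0, 128, (64, 64))  # int64 by default

out = torch.gather(inp, dim=-1, index=index)

# Semantics: out[i, j] = inp[i, index[i, j]].
reference = torch.empty(64, 64)
for i in range(64):
    for j in range(64):
        reference[i, j] = inp[i, index[i, j]]

assert out.shape == (64, 64)
assert torch.equal(out, reference)
```

Note that the output takes the shape of the index tensor, not the input, which is why the benchmark's Size Detail lists input shape, dim, and index shape separately.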

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by unit tests.

Performance

Before:

Operator: gather  Performance Test (dtype=torch.float16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.013024            0.011232               1.160               7.548               8.752          [torch.Size([64, 64]), -1, torch.Size([64, 128])]
SUCCESS               0.013184            0.018400               0.717             119.301              85.482          [torch.Size([256, 256]), -1, torch.Size([256, 512])]
SUCCESS               0.032864            0.139904               0.235             765.757             179.879          [torch.Size([1024, 1024]), -1, torch.Size([1024, 2048])]
SUCCESS               0.332608            8.644352               0.038            1210.594              46.580          [torch.Size([4096, 4096]), -1, torch.Size([4096, 8192])]
SUCCESS               1.523776           54.445824               0.028            1056.988              29.582          [torch.Size([1024, 65536]), -1, torch.Size([1024, 131072])]
SUCCESS               0.062144            0.230880               0.269             988.671             266.112          [torch.Size([10000, 256]), -1, torch.Size([10000, 512])]
SUCCESS              14.877792          572.946045               0.026            1057.189              27.452          [torch.Size([10000, 65536]), -1, torch.Size([10000, 131072])]


Operator: gather  Performance Test (dtype=torch.float32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011744            0.008480               1.385              11.161              15.457          [torch.Size([64, 64]), -1, torch.Size([64, 128])]
SUCCESS               0.011616            0.016960               0.685             180.540             123.653          [torch.Size([256, 256]), -1, torch.Size([256, 512])]
SUCCESS               0.038336            0.143840               0.267             875.272             233.276          [torch.Size([1024, 1024]), -1, torch.Size([1024, 2048])]
SUCCESS               0.387040            9.053536               0.043            1387.120              59.300          [torch.Size([4096, 4096]), -1, torch.Size([4096, 8192])]
SUCCESS               1.671840           54.217888               0.031            1284.503              39.608          [torch.Size([1024, 65536]), -1, torch.Size([1024, 131072])]
SUCCESS               0.069696            0.254976               0.273            1175.390             321.285          [torch.Size([10000, 256]), -1, torch.Size([10000, 512])]
SUCCESS              16.339872          569.732117               0.029            1283.457              36.809          [torch.Size([10000, 65536]), -1, torch.Size([10000, 131072])]


Operator: gather  Performance Test (dtype=torch.bfloat16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.010944            0.008608               1.271               8.982              11.420          [torch.Size([64, 64]), -1, torch.Size([64, 128])]
SUCCESS               0.012384            0.018400               0.673             127.008              85.482          [torch.Size([256, 256]), -1, torch.Size([256, 512])]
SUCCESS               0.033760            0.141472               0.239             745.433             177.886          [torch.Size([1024, 1024]), -1, torch.Size([1024, 2048])]
SUCCESS               0.332864            8.636352               0.039            1209.663              46.623          [torch.Size([4096, 4096]), -1, torch.Size([4096, 8192])]
SUCCESS               1.524480           54.436993               0.028            1056.500              29.587          [torch.Size([1024, 65536]), -1, torch.Size([1024, 131072])]
SUCCESS               0.062240            0.231840               0.268             987.147             265.010          [torch.Size([10000, 256]), -1, torch.Size([10000, 512])]
SUCCESS              14.870400          572.917847               0.026            1057.715              27.454          [torch.Size([10000, 65536]), -1, torch.Size([10000, 131072])]

After:

Operator: gather  Performance Test (dtype=torch.float16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.013152            0.009536               1.379               7.474              10.309          [torch.Size([64, 64]), -1, torch.Size([64, 128])]
SUCCESS               0.012864            0.009664               1.331             122.269             162.755          [torch.Size([256, 256]), -1, torch.Size([256, 512])]
SUCCESS               0.033632            0.031328               1.074             748.270             803.301          [torch.Size([1024, 1024]), -1, torch.Size([1024, 2048])]
SUCCESS               0.333440            0.307360               1.085            1207.573            1310.038          [torch.Size([4096, 4096]), -1, torch.Size([4096, 8192])]
SUCCESS               1.524736            1.411264               1.080            1056.322            1141.255          [torch.Size([1024, 65536]), -1, torch.Size([1024, 131072])]
SUCCESS               0.062272            0.058784               1.059             986.639            1045.182          [torch.Size([10000, 256]), -1, torch.Size([10000, 512])]
SUCCESS              14.870304           13.689120               1.086            1057.721            1148.988          [torch.Size([10000, 65536]), -1, torch.Size([10000, 131072])]


Operator: gather  Performance Test (dtype=torch.float32, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011328            0.008256               1.372              11.571              15.876          [torch.Size([64, 64]), -1, torch.Size([64, 128])]
SUCCESS               0.011776            0.009888               1.191             178.087             212.091          [torch.Size([256, 256]), -1, torch.Size([256, 512])]
SUCCESS               0.039008            0.034944               1.116             860.194             960.234          [torch.Size([1024, 1024]), -1, torch.Size([1024, 2048])]
SUCCESS               0.387648            0.376640               1.029            1384.944            1425.422          [torch.Size([4096, 4096]), -1, torch.Size([4096, 8192])]
SUCCESS               1.672064            1.616960               1.034            1284.331            1328.099          [torch.Size([1024, 65536]), -1, torch.Size([1024, 131072])]
SUCCESS               0.070688            0.066208               1.068            1158.895            1237.313          [torch.Size([10000, 256]), -1, torch.Size([10000, 512])]
SUCCESS              16.339647           15.749472               1.037            1283.474            1331.570          [torch.Size([10000, 65536]), -1, torch.Size([10000, 131072])]


Operator: gather  Performance Test (dtype=torch.bfloat16, mode=cuda,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.011936            0.008064               1.480               8.236              12.190          [torch.Size([64, 64]), -1, torch.Size([64, 128])]
SUCCESS               0.012032            0.010720               1.122             130.723             146.722          [torch.Size([256, 256]), -1, torch.Size([256, 512])]
SUCCESS               0.033792            0.030240               1.117             744.727             832.203          [torch.Size([1024, 1024]), -1, torch.Size([1024, 2048])]
SUCCESS               0.332928            0.308128               1.080            1209.430            1306.772          [torch.Size([4096, 4096]), -1, torch.Size([4096, 8192])]
SUCCESS               1.524384            1.409856               1.081            1056.566            1142.395          [torch.Size([1024, 65536]), -1, torch.Size([1024, 131072])]
SUCCESS               0.061504            0.058944               1.043             998.959            1042.345          [torch.Size([10000, 256]), -1, torch.Size([10000, 512])]
SUCCESS              14.872288           13.688032               1.087            1057.580            1149.080          [torch.Size([10000, 65536]), -1, torch.Size([10000, 131072])]

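For reference, the GBPS columns appear consistent with counting, per gathered element, one input-element read, one int64 index read, and one output write. A quick sanity check under that assumption (the benchmark harness's exact accounting may differ), using the fp16 [1024, 65536] → [1024, 131072] row from the "Before" table:

```python
# Reconstruct the reported Torch GBPS for one row of the fp16 table:
# gather with index shape [1024, 131072], input [1024, 65536],
# dtype float16 (2 bytes), int64 indices (8 bytes).
out_elems = 1024 * 131072
dtype_bytes = 2          # float16
index_bytes = 8          # int64

# Assumed traffic: one input read + one output write per element,
# plus one index read per element.
total_bytes = out_elems * (2 * dtype_bytes + index_bytes)

latency_s = 1.523776e-3  # Torch latency from the "Before" fp16 table
gbps = total_bytes / latency_s / 1e9
print(round(gbps, 3))    # ≈ 1056.988, matching the table's Torch GBPS
```

Under this model the op is bandwidth-bound and dominated by the int64 index traffic, which is why the ~1.03–1.48x speedups after the rewrite track the Torch bandwidth numbers so closely.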