Enable GEMM/dot for FP8 using hipblasLT #3577

Merged
merged 20 commits into from
Nov 13, 2024

CharlieL7
Collaborator

Appears to work correctly on Navi4X.

@CharlieL7 CharlieL7 added the simple (small or simple changes) and FP8 (issues related to FP8 implementation) labels Oct 30, 2024
@CharlieL7 CharlieL7 requested a review from ahsan-ca October 30, 2024 20:31
@CharlieL7 CharlieL7 self-assigned this Oct 30, 2024
@@ -129,9 +129,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_fp8e4m3fnuz_ops.insert("argmin");

std::set<std::string> unsupported_fp8ocp_ops = {};
// TODO update with hipBLASLt support
// TODO: remove this when the flag is removed
#if !MIGRAPHX_ENABLE_HIPBLASLT_GEMM
Contributor

We probably need to add MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM) at the top of the file.

Collaborator Author

I'm not sure what that does exactly. Is the env variable not declared elsewhere?

Contributor

This seems to be how MIGraphX handles environment variables. Declaring the variable this way provides helper functions such as enabled and disabled for querying it, and the variable then accepts any of the following values:

Set to "1", "enable", "enabled", "yes", or "true" to use.

Also, I forgot to mention that we should guard this with if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{})) instead of !MIGRAPHX_ENABLE_HIPBLASLT_GEMM, so that the check actually reads the env variable.
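To illustrate the distinction raised here: a preprocessor test such as #if !NAME is evaluated at compile time, and an identifier that is not defined as a macro evaluates to 0 there, so it cannot observe an environment variable set at run time. The following is a minimal sketch of a run-time check that accepts the values listed above. The helper names is_truthy and env_enabled are hypothetical and are not MIGraphX's implementation; the real enabled function may differ, for example in case handling.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdlib>
#include <string_view>

// Hypothetical helper (not part of MIGraphX): returns true when the
// value is one of the "enable" spellings quoted in the comment above.
inline bool is_truthy(std::string_view value)
{
    constexpr std::array<std::string_view, 5> truthy = {
        "1", "enable", "enabled", "yes", "true"};
    return std::find(truthy.begin(), truthy.end(), value) != truthy.end();
}

// Hypothetical helper: reads the environment variable at run time.
// An unset variable counts as "not enabled".
inline bool env_enabled(const char* name)
{
    assert(name != nullptr);
    const char* raw = std::getenv(name);
    return raw != nullptr && is_truthy(raw);
}
```

With helpers like these, a guard would read if(!env_enabled("MIGRAPHX_ENABLE_HIPBLASLT_GEMM")) and would be decided each time the program runs, not when it is compiled.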

Contributor

Does this environment variable not clash with the one already declared in lowering.cpp? Thanks.

Contributor

I don't think they should clash.

codecov bot commented Oct 30, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.17%. Comparing base (1cfd6c2) to head (519a63b).
Report is 5 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3577      +/-   ##
===========================================
- Coverage    92.17%   92.17%   -0.01%     
===========================================
  Files          513      513              
  Lines        21560    21558       -2     
===========================================
- Hits         19873    19871       -2     
  Misses        1687     1687              


@CharlieL7 CharlieL7 marked this pull request as ready for review November 1, 2024 16:25
@CharlieL7 CharlieL7 requested a review from causten as a code owner November 1, 2024 16:25
Contributor

@lakhinderwalia lakhinderwalia left a comment

Approved. Is there any test case that verifies this set of ops? Thanks.

@CharlieL7
Collaborator Author

Approved. Is there any test case that verifies this set of Ops. Thanks.

Yes, the test_verify gemm tests will now use hipblaslt if the flag is enabled.

@CharlieL7
Collaborator Author

Created an issue tracking the preprocessor conditional I added: #3592

@CharlieL7 CharlieL7 linked an issue Nov 6, 2024 that may be closed by this pull request
@CharlieL7
Collaborator Author

I tried using CMake's check_type_size and check_symbol_exists to detect whether the types are available in the installed hipblaslt version. The HIP_R_F8... values are not really types, so check_type_size doesn't work, and I couldn't get check_symbol_exists to work either. The last option I found was check_source_compiles, but that's way too troublesome to use.

@pfultz2
Collaborator

pfultz2 commented Nov 8, 2024

The types are defined as enum values in hip, not in hipblaslt:

https://github.com/ROCm/HIP/blob/3d60bd3a6415c280bd1fe63767ae8e10eea4d2d1/include/hip/library_types.h#L61

So we shouldn't check the hipblaslt version. It looks like hipblaslt already checks the hip version and then defines ROCM_USE_FLOAT8:

https://github.com/ROCm/hipBLASLt/blob/b2adca84509dd31e31b8f42044389128d199b62e/library/include/hipblaslt.h#L67

This appears to be what enables the non-fnuz types:

https://github.com/ROCm/hipBLASLt/blob/b2adca84509dd31e31b8f42044389128d199b62e/library/src/amd_detail/rocblaslt/src/utility.cpp#L79

So we can probably use #ifdef ROCM_USE_FLOAT8 for this.
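A minimal sketch of what such a guard might look like, assuming ROCM_USE_FLOAT8 is visible to the translation unit. In a real build the macro would come from the hipblaslt header as described above; this sketch does not include that header, so the macro is only defined if passed on the command line (for example -DROCM_USE_FLOAT8). The function name unsupported_fp8ocp_ops, the constant has_ocp_fp8_gemm, and the op name "dot" are illustrative assumptions, not the merged code.

```cpp
#include <set>
#include <string>

// Compile-time switch derived from the macro discussed above.
// Assumption: the macro is defined whenever OCP FP8 GEMM is usable.
#ifdef ROCM_USE_FLOAT8
constexpr bool has_ocp_fp8_gemm = true;
#else
constexpr bool has_ocp_fp8_gemm = false;
#endif

// Hypothetical helper: builds the set of ops that should be treated
// as unsupported for OCP FP8 when the macro is absent.
inline std::set<std::string> unsupported_fp8ocp_ops()
{
    std::set<std::string> ops = {};
    if(!has_ocp_fp8_gemm)
    {
        // "dot" is an assumed op name used only for illustration.
        ops.insert("dot");
    }
    return ops;
}
```

Because the decision is made by the preprocessor, no CMake-side detection is needed; the trade-off is that the result is fixed at build time for the headers that were installed then.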

@CharlieL7
Collaborator Author

Works well with the variable

@migraphx-bot
Collaborator

| Test | Batch | Rate new (519a63) | Rate old (c51bea) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,260.35 | 3,257.81 | 0.08% | |
| torchvision-resnet50_fp16 | 64 | 6,992.56 | 6,987.81 | 0.07% | |
| torchvision-densenet121 | 32 | 2,435.37 | 2,434.57 | 0.03% | |
| torchvision-densenet121_fp16 | 32 | 4,052.07 | 4,065.61 | -0.33% | |
| torchvision-inceptionv3 | 32 | 1,628.86 | 1,637.17 | -0.51% | |
| torchvision-inceptionv3_fp16 | 32 | 2,745.48 | 2,759.26 | -0.50% | |
| cadene-inceptionv4 | 16 | 766.04 | 776.31 | -1.32% | |
| cadene-resnext64x4 | 16 | 810.66 | 811.75 | -0.13% | |
| slim-mobilenet | 64 | 7,468.29 | 7,533.16 | -0.86% | |
| slim-nasnetalarge | 64 | 208.48 | 211.39 | -1.38% | |
| slim-resnet50v2 | 64 | 3,440.40 | 3,504.83 | -1.84% | |
| bert-mrpc-onnx | 8 | 1,151.09 | 1,146.47 | 0.40% | |
| bert-mrpc-tf | 1 | 465.75 | 473.89 | -1.72% | |
| pytorch-examples-wlang-gru | 1 | 419.90 | 425.31 | -1.27% | |
| pytorch-examples-wlang-lstm | 1 | 379.35 | 408.68 | -7.18% | 🔴 |
| torchvision-resnet50_1 | 1 | 820.74 | 771.75 | 6.35% | 🔆 |
| cadene-dpn92_1 | 1 | 400.14 | 399.01 | 0.28% | |
| cadene-resnext101_1 | 1 | 382.79 | 383.85 | -0.28% | |
| onnx-taau-downsample | 1 | 346.22 | 343.09 | 0.91% | |
| dlrm-criteoterabyte | 1 | 33.35 | 33.31 | 0.12% | |
| dlrm-criteoterabyte_fp16 | 1 | 52.72 | 52.71 | 0.02% | |
| agentmodel | 1 | 8,340.77 | 8,235.67 | 1.28% | |
| unet_fp16 | 2 | 58.82 | 58.90 | -0.14% | |
| resnet50v1_fp16 | 1 | 917.54 | 940.89 | -2.48% | |
| resnet50v1_int8 | 1 | 1,035.41 | 1,025.93 | 0.92% | |
| bert_base_cased_fp16 | 64 | 1,170.28 | 1,170.88 | -0.05% | |
| bert_large_uncased_fp16 | 32 | 363.71 | 363.69 | 0.01% | |
| bert_large_fp16 | 1 | 198.93 | 200.14 | -0.60% | |
| distilgpt2_fp16 | 16 | 2,199.74 | 2,200.77 | -0.05% | |
| yolov5s | 1 | 534.32 | 535.15 | -0.16% | |
| tinyllama | 1 | 43.67 | 43.41 | 0.59% | |
| vicuna-fastchat | 1 | 171.49 | 178.09 | -3.71% | 🔴 |
| whisper-tiny-encoder | 1 | 418.42 | 418.18 | 0.06% | |
| whisper-tiny-decoder | 1 | 430.87 | 427.58 | 0.77% | |

This build is not recommended to merge 🔴

@migraphx-bot
Collaborator


     ✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance

     ✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance

     ✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance

     ✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance

     ✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance

     ✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance

     ✅ agentmodel: PASSED: MIGraphX meets tolerance

     ✅ unet: PASSED: MIGraphX meets tolerance

     ✅ resnet50v1: PASSED: MIGraphX meets tolerance

     ✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance

     🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output

     ✅ bert_large: PASSED: MIGraphX meets tolerance

     ✅ yolov5s: PASSED: MIGraphX meets tolerance

     ✅ tinyllama: PASSED: MIGraphX meets tolerance

     ✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance

     ✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance

     ✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

@causten causten merged commit 495d3eb into develop Nov 13, 2024
36 of 43 checks passed
@causten causten deleted the ocp_fp8_hipblaslt branch November 13, 2024 15:39
Successfully merging this pull request may close these issues.

Update OCP FP8 support with hipblaslt support
6 participants