Skip to content

Commit

Permalink
fix: preprocessor directives logic error if/else (ROCm#1764)
Browse files Browse the repository at this point in the history
* fix: preprocessors logic error if/else

* fix: added macros as preferred by CK team
  • Loading branch information
deepsek authored and kylasa committed Jan 19, 2025
1 parent 8048955 commit 93ecdd5
Showing 1 changed file with 47 additions and 57 deletions.
104 changes: 47 additions & 57 deletions profiler/src/profile_grouped_gemm_fixed_nk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ enum struct GemmDataType
F16_F16_F16, // 1
F16_F8_F16, // 2
F16_I8_F16, // 3

};

#define OP_NAME "grouped_gemm_fixed_nk"
Expand All @@ -39,7 +38,6 @@ std::vector<int> argToIntArray(char* input)
{
out.push_back(std::stoi(item));
}

return out;
}

Expand Down Expand Up @@ -83,14 +81,6 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;

using F32 = float;
using F16 = ck::half_t;
#if defined(CK_ENABLE_FP8)
using F8 = ck::f8_t;
#endif
using BF16 = ck::bhalf_t;
using I8 = int8_t;

int n_warmup = 1;
int n_iter = 10;
if(argc == 17)
Expand All @@ -99,13 +89,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter = std::stoi(argv[16]);
}

#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
I8,
BF16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -123,12 +112,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<BF16,
I8,
BF16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -146,14 +135,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16)
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
#if defined(CK_ENABLE_FP8)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F16,
F16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
ck::f8_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -171,12 +159,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F16,
F16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
ck::f8_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -195,13 +183,13 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
#if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F8,
F16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
int8_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -219,12 +207,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
F8,
F16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::half_t,
int8_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -238,18 +226,19 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
kbatch,
1,
n_warmup,
n_iter);
}
#endif
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
#if defined(CK_ENABLE_BF16)
#if defined(CK_ENABLE_INT8)
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
I8,
F16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -267,12 +256,12 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
n_warmup,
n_iter);
}
else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_fixed_nk_impl<F16,
I8,
F16,
F32,
ck::profiler::profile_grouped_gemm_fixed_nk_impl<ck::bhalf_t,
int8_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
Expand All @@ -286,10 +275,11 @@ int profile_grouped_gemm_fixed_nk(int argc, char* argv[])
StrideAs,
StrideBs,
StrideCs,
1,
kbatch,
n_warmup,
n_iter);
}
#endif
#endif
else
{
Expand Down

0 comments on commit 93ecdd5

Please sign in to comment.