From 255c38eda353d3470d2441b3e43fa306499ef642 Mon Sep 17 00:00:00 2001 From: Bibek Ghimire Date: Wed, 11 Dec 2024 09:16:17 +0000 Subject: [PATCH] add ck bfp16 for fwd and bwd --- .../conv/conv_hip_implicit_gemm_bwd_data_xdlops.cpp | 12 ++++++++---- .../conv/conv_hip_implicit_gemm_fwd_xdlops.cpp | 12 ++++++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/solver/conv/conv_hip_implicit_gemm_bwd_data_xdlops.cpp b/src/solver/conv/conv_hip_implicit_gemm_bwd_data_xdlops.cpp index 703dc2b7a9..6d42e227ba 100644 --- a/src/solver/conv/conv_hip_implicit_gemm_bwd_data_xdlops.cpp +++ b/src/solver/conv/conv_hip_implicit_gemm_bwd_data_xdlops.cpp @@ -181,12 +181,12 @@ void PerformanceConfigHipImplicitGemmBwdXdlops::HeuristicInit( { case miopenHalf: Init(problem); break; case miopenFloat: Init(problem); break; + case miopenBFloat16: Init(problem); break; case miopenFloat8: case miopenBFloat8: case miopenInt8: case miopenInt32: case miopenInt64: - case miopenBFloat16: case miopenDouble: break; } #endif @@ -223,12 +223,12 @@ bool PerformanceConfigHipImplicitGemmBwdXdlops::IsValid( { case miopenHalf: return CheckIsSupportCKArgs(problem); case miopenFloat: return CheckIsSupportCKArgs(problem); + case miopenBFloat16: return CheckIsSupportCKArgs(problem); case miopenFloat8: case miopenBFloat8: case miopenInt8: case miopenInt32: case miopenInt64: - case miopenBFloat16: case miopenDouble: break; } #endif @@ -304,12 +304,12 @@ bool ConvHipImplicitGemmBwdXdlops::IsApplicable( { case miopenHalf: return CheckCKApplicability(problem); case miopenFloat: return CheckCKApplicability(problem); + case miopenBFloat16: return CheckCKApplicability(problem); case miopenFloat8: case miopenBFloat8: case miopenInt8: case miopenInt32: case miopenInt64: - case miopenBFloat16: case miopenDouble: break; } #endif @@ -334,10 +334,14 @@ ConvSolution ConvHipImplicitGemmBwdXdlops::GetSolution( CKArgs, miopen::conv::DataInvokeParams>( ctx, problem, config.kernel_id); + case miopenBFloat16: + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenInt8: case miopenInt32: case miopenInt64: - case miopenBFloat16: case miopenDouble: case miopenFloat8: case miopenBFloat8: diff --git a/src/solver/conv/conv_hip_implicit_gemm_fwd_xdlops.cpp b/src/solver/conv/conv_hip_implicit_gemm_fwd_xdlops.cpp index 84fba1e862..0e1bb5bca8 100644 --- a/src/solver/conv/conv_hip_implicit_gemm_fwd_xdlops.cpp +++ b/src/solver/conv/conv_hip_implicit_gemm_fwd_xdlops.cpp @@ -182,11 +182,11 @@ void PerformanceConfigHipImplicitGemmFwdXdlops::HeuristicInit( case miopenInt8: Init(problem); break; case miopenHalf: Init(problem); break; case miopenFloat: Init(problem); break; + case miopenBFloat16: Init(problem); break; case miopenFloat8: case miopenBFloat8: case miopenInt64: case miopenInt32: - case miopenBFloat16: case miopenDouble: break; } #endif @@ -225,11 +225,11 @@ bool PerformanceConfigHipImplicitGemmFwdXdlops::IsValid( case miopenInt8: return CheckIsSupportCKArgs(problem); case miopenHalf: return CheckIsSupportCKArgs(problem); case miopenFloat: return CheckIsSupportCKArgs(problem); + case miopenBFloat16: return CheckIsSupportCKArgs(problem); case miopenFloat8: case miopenBFloat8: case miopenInt64: case miopenInt32: - case miopenBFloat16: case miopenDouble: break; } #endif @@ -306,11 +306,11 @@ bool ConvHipImplicitGemmFwdXdlops::IsApplicable( case miopenInt8: return CheckCKApplicability(problem); case miopenHalf: return CheckCKApplicability(problem); case miopenFloat: return CheckCKApplicability(problem); + case miopenBFloat16: return CheckCKApplicability(problem); case miopenFloat8: case miopenBFloat8: case miopenInt64: case miopenInt32: - case miopenBFloat16: case miopenDouble: break; } #endif @@ -336,9 +336,13 @@ ConvSolution ConvHipImplicitGemmFwdXdlops::GetSolution( case miopenFloat: return InitInvokerFactoryNHWC, CKArgs, miopen::conv::DataInvokeParams>( ctx, problem, config.kernel_id); + case miopenBFloat16: + return InitInvokerFactoryNHWC, + CKArgs, + miopen::conv::DataInvokeParams>( + ctx, problem, config.kernel_id); case miopenInt64: case miopenInt32: - case miopenBFloat16: case miopenDouble: case miopenFloat8: case miopenBFloat8: