diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index f5ec898d8d5..4e282cc01ff 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -70,8 +70,14 @@ hipDataType get_type_hipblas(shape::type_t type) case shape::int32_type: return HIP_R_32I; case shape::uint32_type: return HIP_R_32U; case shape::fp8e4m3fnuz_type: return HIP_R_8F_E4M3_FNUZ; +// TODO: remove this preprocessor conditional once the HIP version provides these types by default +#ifdef ROCM_USE_FLOAT8 + case shape::fp8e4m3fn_type: return HIP_R_8F_E4M3; + case shape::fp8e5m2_type: return HIP_R_8F_E5M2; +#else case shape::fp8e4m3fn_type: case shape::fp8e5m2_type: +#endif case shape::tuple_type: case shape::bool_type: case shape::uint16_type: diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index db76c5a24ac..fa01c514c0c 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -81,6 +81,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC) #ifndef _WIN32 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_CK) #endif +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPBLASLT_GEMM) std::vector target::get_passes(migraphx::context& gctx, const compile_options& options) const { @@ -129,9 +130,12 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_fp8e4m3fnuz_ops.insert("argmin"); std::set unsupported_fp8ocp_ops = {}; - // TODO update with hipBLASLt support - unsupported_fp8ocp_ops.insert("dot"); - unsupported_fp8ocp_ops.insert("quant_dot"); + // TODO: remove this when the flag is removed + if(not enabled(MIGRAPHX_ENABLE_HIPBLASLT_GEMM{})) + { + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); + } #if MIGRAPHX_USE_MIOPEN // MIOpen doesn't have support for fp8 pooling yet. 
unsupported_fp8ocp_ops.insert("pooling"); @@ -140,6 +144,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti { unsupported_fp8ocp_ops.insert("convolution"); unsupported_fp8ocp_ops.insert("quant_convolution"); + unsupported_fp8ocp_ops.insert("dot"); + unsupported_fp8ocp_ops.insert("quant_dot"); } // add all device kernels unsupported_fp8ocp_ops.insert("logsoftmax");