added bfloat16 case
	modified:   ../paddle/phi/kernels/funcs/blas/blas_impl.cu.h
	modified:   ../paddle/phi/kernels/impl/baddbmm_kernel_impl.h
Qin-sx committed Jan 10, 2025
1 parent a0ac3ee commit fdf6322
Showing 2 changed files with 148 additions and 2 deletions.
140 changes: 140 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
@@ -1398,6 +1398,75 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
float alpha,
const phi::dtype::bfloat16 *A,
const phi::dtype::bfloat16 *B,
float beta,
phi::dtype::bfloat16 *C) const {
#if CUDA_VERSION >= 11000
// Note that cuBLAS follows Fortran (column-major) order, so the argument
// order differs from the CBLAS convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
common::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80,"
"but received %d",
context_.GetComputeCapability()));

float h_alpha = alpha;
float h_beta = beta;

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
CUDA_R_16BF,
ldb,
A,
CUDA_R_16BF,
lda,
&h_beta,
C,
CUDA_R_16BF,
N,
CUDA_R_32F,
algo));
});
#else
// raise error
PADDLE_THROW(common::errors::Unimplemented(
"cublasGemmEx with bfloat16 requires CUDA 11.0 or later"));
#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
@@ -2117,6 +2186,77 @@ inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
#endif // CUDA_VERSION >= 11000
}

template <>
template <>
inline void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
CBLAS_TRANSPOSE transB,
int M,
int N,
int K,
float alpha,
const phi::dtype::bfloat16 *A,
const phi::dtype::bfloat16 *B,
float beta,
phi::dtype::bfloat16 *C,
int batchCount,
int64_t strideA,
int64_t strideB) const {
#if CUDA_VERSION >= 11000
// Note that cuBLAS follows Fortran (column-major) order, so the argument
// order differs from the CBLAS convention.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
cublasOperation_t cuTransA =
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
const int64_t strideC = static_cast<int64_t>(M) * N;

PADDLE_ENFORCE_GE(
context_.GetComputeCapability(),
80,
common::errors::InvalidArgument(
"cublas bf16 gemm requires GPU compute capability >= 80, "
"but received %d",
context_.GetComputeCapability()));

float h_alpha = alpha;
float h_beta = beta;

cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
bool use_tensor_op_math = context_.tensor_core_available();
if (use_tensor_op_math) {
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
}
VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");

context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasGemmStridedBatchedEx(handle,
cuTransB,
cuTransA,
N,
M,
K,
&h_alpha,
B,
CUDA_R_16BF,
ldb,
strideB,
A,
CUDA_R_16BF,
lda,
strideA,
&h_beta,
C,
CUDA_R_16BF,
ldc,
strideC,
batchCount,
CUBLAS_COMPUTE_32F,
algo));
});
#else
// raise error
PADDLE_THROW(common::errors::Unimplemented(
"cublasGemmStridedBatchedEx with bfloat16 requires CUDA 11.0 or later"));
#endif // CUDA_VERSION >= 11000
}

template <>
template <typename T>
void Blas<phi::GPUContext>::BatchedGEMM(CBLAS_TRANSPOSE transA,
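For context, a minimal usage sketch (not part of the commit) of how the new bfloat16 overload would be reached through the Blas wrapper. The function name, `dev_ctx`, and the device pointers are hypothetical placeholders; the sketch assumes blas.h declares a GEMM overload matching float alpha/beta with bfloat16 operands, which is what the explicit specialization above implements.

```cpp
#include "paddle/phi/kernels/funcs/blas/blas.h"

// Hypothetical example: C = 1.0 * A * B + 0.0 * C with bf16 storage.
// Passing float alpha/beta with bf16 pointers selects the new
// mixed-precision overload (bf16 in/out, fp32 accumulation in cublasGemmEx).
void GemmBf16Example(const phi::GPUContext& dev_ctx,
                     const phi::dtype::bfloat16* d_a,  // M x K, device memory
                     const phi::dtype::bfloat16* d_b,  // K x N, device memory
                     phi::dtype::bfloat16* d_c,        // M x N, device memory
                     int M, int N, int K) {
  auto blas =
      phi::funcs::GetBlas<phi::GPUContext, phi::dtype::bfloat16>(dev_ctx);
  blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, d_a, d_b, 0.0f, d_c);
}
```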
10 changes: 8 additions & 2 deletions paddle/phi/kernels/impl/baddbmm_kernel_impl.h
@@ -18,10 +18,12 @@ limitations under the License. */

#include "glog/logging.h"

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/baddbmm_kernel.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

namespace phi {

template <typename T,
@@ -113,8 +115,12 @@ void BaddbmmKernel(const Context& dev_ctx,
funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 3>::Eval(
place, eigen_out, eigen_input, bcast_dims);

-  // special case for float16
-  if constexpr (std::is_same_v<T, phi::dtype::float16>) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  // special case when the compute type is float (covers float16 and bfloat16)
+  if constexpr (std::is_same_v<MPType, float>) {
+    VLOG(4) << "Function: baddbmm, Type of T: " << typeid(T).name();
+    VLOG(4) << "Function: baddbmm, Type of MPType: " << typeid(MPType).name();
float t_alpha = alpha;
float t_beta = beta;
if (x_dims[0] == 1) {
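The kernel-side change swaps the float16-only branch for a dispatch on the accumulation type: `phi::dtype::MPTypeTrait` maps both `float16` and `bfloat16` to `float`, so one `if constexpr` branch now covers both low-precision inputs. A simplified, self-contained sketch of that trait pattern (stand-in types, not Paddle's actual definitions):

```cpp
#include <cstdio>
#include <type_traits>

// Stand-ins for phi::dtype::float16 / phi::dtype::bfloat16.
struct float16 {};
struct bfloat16 {};

// Simplified MPTypeTrait: low-precision types accumulate in float,
// everything else keeps its own type.
template <typename T>
struct MPTypeTrait { using Type = T; };
template <> struct MPTypeTrait<float16> { using Type = float; };
template <> struct MPTypeTrait<bfloat16> { using Type = float; };

template <typename T>
void Baddbmm() {
  using MPType = typename MPTypeTrait<T>::Type;
  if constexpr (std::is_same_v<MPType, float>) {
    // float16 and bfloat16 take the float-accumulation path:
    // cublasGemmEx / cublasGemmStridedBatchedEx with a 32-bit compute type.
    std::puts("float-accumulation path");
  } else {
    // e.g. double falls through to the generic path.
    std::puts("generic path");
  }
}

int main() {
  Baddbmm<bfloat16>();  // float-accumulation path
  Baddbmm<double>();    // generic path
}
```

Since `MPTypeTrait<float>::Type` is also `float`, plain float inputs satisfy the same condition; the branch selects every type that accumulates in fp32, not float16 alone as before.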
