diff --git a/cupy_backends/cuda/libs/cublas.pyx b/cupy_backends/cuda/libs/cublas.pyx
index 03484afcd22..bbd43a452c8 100644
--- a/cupy_backends/cuda/libs/cublas.pyx
+++ b/cupy_backends/cuda/libs/cublas.pyx
@@ -487,6 +487,52 @@ ELSE:
         if status != 0:
             raise CUBLASError(status)
 
+    ###############################################################################
+    # Enum Conversion
+    ###############################################################################
+    cpdef int convert_blas_dtype(int dtype) nogil:
+        # Translate a CUDA data-type enum value into the hipBLAS value
+        # expected by HIP releases older than 6.0; every other input or
+        # build gets ``dtype`` back unchanged.
+        # The ``0 <`` lower bound mirrors the guards on the GEMM wrappers
+        # in this file, so a build with CUPY_HIP_VERSION == 0 (presumably
+        # CUDA -- confirm) never receives hipBLAS-only enum values.
+        IF 0 < CUPY_HIP_VERSION < 60000000:
+            if dtype == 0:  # CUDA_R_32F
+                return 151  # HIPBLAS_R_32F
+            elif dtype == 1:  # CUDA_R_64F
+                return 152  # HIPBLAS_R_64F
+            elif dtype == 2:  # CUDA_R_16F
+                return 150  # HIPBLAS_R_16F
+            elif dtype == 3:  # CUDA_R_8I
+                return 160  # HIPBLAS_R_8I
+            elif dtype == 4:  # CUDA_C_32F
+                return 154  # HIPBLAS_C_32F
+            elif dtype == 5:  # CUDA_C_64F
+                return 155  # HIPBLAS_C_64F
+            elif dtype == 6:  # CUDA_C_16F
+                return 153  # HIPBLAS_C_16F
+        return dtype
+
+    cpdef int convert_blas_computetype(int ctype) nogil:
+        # Translate a CUDA data-type enum value used as a compute type
+        # into the hipblasComputeType_t value on HIP >= 6.0; on any other
+        # build, or for unlisted values, defer to convert_blas_dtype().
+        # hipblasComputeType_t is supported in V2 API
+        IF CUPY_HIP_VERSION >= 60000000:
+            if ctype == 2:  # CUDA_R_16F
+                return 0  # HIPBLAS_COMPUTE_16F
+            elif ctype == 0:  # CUDA_R_32F
+                return 2  # HIPBLAS_COMPUTE_32F
+            elif ctype == 1:  # CUDA_R_64F
+                return 7  # HIPBLAS_COMPUTE_64F
+            elif ctype == 4:  # CUDA_C_32F
+                return 2  # HIPBLAS_COMPUTE_32F
+            elif ctype == 5:  # CUDA_C_64F
+                return 7  # HIPBLAS_COMPUTE_64F
+            elif ctype == 6:  # CUDA_C_16F
+                return 0  # HIPBLAS_COMPUTE_16F
+        return convert_blas_dtype(ctype)
     ###############################################################################
     # Context
     ###############################################################################
@@ -1336,9 +1382,9 @@ ELSE:
         with nogil:
             status = cublasSgemmEx(
                 handle, transa, transb, m, n, k,
-                alpha, A, Atype, lda,
-                B, Btype, ldb, beta,
-                C, Ctype, ldc)
+                alpha, A, (convert_blas_dtype(Atype)), lda,
+                B, (convert_blas_dtype(Btype)), ldb, beta,
+                C, (convert_blas_dtype(Ctype)), ldc)
         check_status(status)
 
 
@@ -1537,39 +1583,39 @@ ELSE:
                     status = cublasGemmEx_v11(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda,
-                        B, Btype, ldb,
+                        A, (convert_blas_dtype(Atype)), lda,
+                        B, (convert_blas_dtype(Btype)), ldb,
                         beta,
-                        C, Ctype, ldc,
-                        computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc,
+                        (convert_blas_computetype(computeType)), algo)
                 ELSE:
                     status = cublasGemmEx_v11(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda,
-                        B, Btype, ldb,
+                        A, (convert_blas_dtype(Atype)), lda,
+                        B, (convert_blas_dtype(Btype)), ldb,
                         beta,
-                        C, Ctype, ldc,
-                        computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc,
+                        (convert_blas_computetype(computeType)), algo)
             else:
                 IF 0 < CUPY_HIP_VERSION < 60000000:
                     status = cublasGemmEx(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda,
-                        B, Btype, ldb,
+                        A, (convert_blas_dtype(Atype)), lda,
+                        B, (convert_blas_dtype(Btype)), ldb,
                         beta,
-                        C, Ctype, ldc,
-                        computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc,
+                        (convert_blas_computetype(computeType)), algo)
                 ELSE:
                     status = cublasGemmEx(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda,
-                        B, Btype, ldb,
+                        A, (convert_blas_dtype(Atype)), lda,
+                        B, (convert_blas_dtype(Btype)), ldb,
                         beta,
-                        C, Ctype, ldc,
-                        computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc,
+                        (convert_blas_computetype(computeType)), algo)
         check_status(status)
 
 
@@ -1588,39 +1634,39 @@ ELSE:
                     status = cublasGemmStridedBatchedEx_v11(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda, strideA,
-                        B, Btype, ldb, strideB,
+                        A, (convert_blas_dtype(Atype)), lda, strideA,
+                        B, (convert_blas_dtype(Btype)), ldb, strideB,
                         beta,
-                        C, Ctype, ldc, strideC,
-                        batchCount, computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc, strideC,
+                        batchCount, (convert_blas_computetype(computeType)), algo)
                 ELSE:
                     status = cublasGemmStridedBatchedEx_v11(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda, strideA,
-                        B, Btype, ldb, strideB,
+                        A, (convert_blas_dtype(Atype)), lda, strideA,
+                        B, (convert_blas_dtype(Btype)), ldb, strideB,
                         beta,
-                        C, Ctype, ldc, strideC,
-                        batchCount, computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc, strideC,
+                        batchCount, (convert_blas_computetype(computeType)), algo)
             else:
                 IF 0 < CUPY_HIP_VERSION < 60000000:
                     status = cublasGemmStridedBatchedEx(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda, strideA,
-                        B, Btype, ldb, strideB,
+                        A, (convert_blas_dtype(Atype)), lda, strideA,
+                        B, (convert_blas_dtype(Btype)), ldb, strideB,
                         beta,
-                        C, Ctype, ldc, strideC,
-                        batchCount, computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc, strideC,
+                        batchCount, (convert_blas_computetype(computeType)), algo)
                 ELSE:
                     status = cublasGemmStridedBatchedEx(
                         handle, transa, transb, m, n, k,
                         alpha,
-                        A, Atype, lda, strideA,
-                        B, Btype, ldb, strideB,
+                        A, (convert_blas_dtype(Atype)), lda, strideA,
+                        B, (convert_blas_dtype(Btype)), ldb, strideB,
                         beta,
-                        C, Ctype, ldc, strideC,
-                        batchCount, computeType, algo)
+                        C, (convert_blas_dtype(Ctype)), ldc, strideC,
+                        batchCount, (convert_blas_computetype(computeType)), algo)
         check_status(status)
 
 