Skip to content

Commit

Permalink
Add dtype and compute type conversions in cublas.pyx
Browse files Browse the repository at this point in the history
  • Loading branch information
pnunna93 committed Feb 14, 2024
1 parent 6084c3a commit 01abc5a
Showing 1 changed file with 72 additions and 35 deletions.
107 changes: 72 additions & 35 deletions cupy_backends/cuda/libs/cublas.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,43 @@ ELSE:
if status != 0:
raise CUBLASError(status)

###############################################################################
# Enum Conversion
###############################################################################
cpdef int convert_blas_dtype(int dtype) nogil:
    # Translate a CUDA data-type enum value into the value the active BLAS
    # backend expects.
    #
    # dtype: integer enum value using the CUDA numbering (CUDA_R_32F == 0,
    #        CUDA_R_64F == 1, ... as annotated on each branch below).
    # Returns: the matching hipBLAS datatype value on HIP builds older than
    #          60000000; otherwise `dtype` unchanged. Values outside 0..6 are
    #          always returned unchanged.
    #
    # The `0 <` lower bound matches the guards used by the GEMM wrappers in
    # this file (`IF 0 < CUPY_HIP_VERSION < 60000000`). Without it the mapping
    # would also be compiled in when CUPY_HIP_VERSION is 0 (presumably the
    # non-HIP/CUDA build -- confirm against the build configuration), and
    # CUDA enum values would be rewritten into hipBLAS ones.
    IF 0 < CUPY_HIP_VERSION < 60000000:
        if dtype == 0:  # CUDA_R_32F
            return 151  # HIPBLAS_R_32F
        elif dtype == 1:  # CUDA_R_64F
            return 152  # HIPBLAS_R_64F
        elif dtype == 2:  # CUDA_R_16F
            return 150  # HIPBLAS_R_16F
        elif dtype == 3:  # CUDA_R_8I
            return 160  # HIPBLAS_R_8I
        elif dtype == 4:  # CUDA_C_32F
            return 154  # HIPBLAS_C_32F
        elif dtype == 5:  # CUDA_C_64F
            return 155  # HIPBLAS_C_64F
        elif dtype == 6:  # CUDA_C_16F
            return 153  # HIPBLAS_C_16F
    # No translation applies: pass the caller's value through untouched.
    return dtype

cpdef int convert_blas_computetype(int ctype) nogil:
    # Translate a compute-type value given in CUDA data-type numbering into
    # the value passed on to the GEMM call.
    #
    # When CUPY_HIP_VERSION >= 60000000 (hipblasComputeType_t is supported
    # in the V2 API) the real and complex variants of each precision map to
    # the same HIPBLAS_COMPUTE_* value. Any value not handled here is handed
    # to convert_blas_dtype() and its result is returned.
    IF CUPY_HIP_VERSION >= 60000000:
        # Half precision: CUDA_R_16F (2) / CUDA_C_16F (6)
        if ctype == 2 or ctype == 6:
            return 0  # HIPBLAS_COMPUTE_16F
        # Single precision: CUDA_R_32F (0) / CUDA_C_32F (4)
        if ctype == 0 or ctype == 4:
            return 2  # HIPBLAS_COMPUTE_32F
        # Double precision: CUDA_R_64F (1) / CUDA_C_64F (5)
        if ctype == 1 or ctype == 5:
            return 7  # HIPBLAS_COMPUTE_64F
    return convert_blas_dtype(ctype)

###############################################################################
# Context
Expand Down Expand Up @@ -1336,9 +1373,9 @@ ELSE:
with nogil:
status = cublasSgemmEx(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const float*>alpha, <const void*>A, <DataType>Atype, lda,
<const void*>B, <DataType>Btype, ldb, <const float*>beta,
<void*>C, <DataType>Ctype, ldc)
<const float*>alpha, <const void*>A, <DataType>(convert_blas_dtype(Atype)), lda,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb, <const float*>beta,
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc)
check_status(status)


Expand Down Expand Up @@ -1537,39 +1574,39 @@ ELSE:
status = cublasGemmEx_v11(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda,
<const void*>B, <DataType>Btype, ldb,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc,
<DataType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc,
<DataType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
ELSE:
status = cublasGemmEx_v11(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda,
<const void*>B, <DataType>Btype, ldb,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc,
<ComputeType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc,
<ComputeType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
else:
IF 0 < CUPY_HIP_VERSION < 60000000:
status = cublasGemmEx(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda,
<const void*>B, <DataType>Btype, ldb,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc,
<DataType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc,
<DataType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
ELSE:
status = cublasGemmEx(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda,
<const void*>B, <DataType>Btype, ldb,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc,
<ComputeType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc,
<ComputeType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
check_status(status)


Expand All @@ -1588,39 +1625,39 @@ ELSE:
status = cublasGemmStridedBatchedEx_v11(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda, <long long>strideA,
<const void*>B, <DataType>Btype, ldb, <long long>strideB,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda, <long long>strideA,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb, <long long>strideB,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc, <long long>strideC,
batchCount, <DataType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc, <long long>strideC,
batchCount, <DataType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
ELSE:
status = cublasGemmStridedBatchedEx_v11(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda, <long long>strideA,
<const void*>B, <DataType>Btype, ldb, <long long>strideB,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda, <long long>strideA,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb, <long long>strideB,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc, <long long>strideC,
batchCount, <ComputeType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc, <long long>strideC,
batchCount, <ComputeType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
else:
IF 0 < CUPY_HIP_VERSION < 60000000:
status = cublasGemmStridedBatchedEx(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda, <long long>strideA,
<const void*>B, <DataType>Btype, ldb, <long long>strideB,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda, <long long>strideA,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb, <long long>strideB,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc, <long long>strideC,
batchCount, <DataType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc, <long long>strideC,
batchCount, <DataType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
ELSE:
status = cublasGemmStridedBatchedEx(
<Handle>handle, <Operation>transa, <Operation>transb, m, n, k,
<const void*>alpha,
<const void*>A, <DataType>Atype, lda, <long long>strideA,
<const void*>B, <DataType>Btype, ldb, <long long>strideB,
<const void*>A, <DataType>(convert_blas_dtype(Atype)), lda, <long long>strideA,
<const void*>B, <DataType>(convert_blas_dtype(Btype)), ldb, <long long>strideB,
<const void*>beta,
<void*>C, <DataType>Ctype, ldc, <long long>strideC,
batchCount, <ComputeType>computeType, <GemmAlgo>algo)
<void*>C, <DataType>(convert_blas_dtype(Ctype)), ldc, <long long>strideC,
batchCount, <ComputeType>(convert_blas_computetype(computeType)), <GemmAlgo>algo)
check_status(status)


Expand Down

0 comments on commit 01abc5a

Please sign in to comment.