[ROCm] Fix Int_mm() Integration with hipblasLT (pytorch#122431)
The PR:

- fixes the _int_mm()/int8_gemm() integration with the hipblasLT backend (requires ROCm 6.0); a usage sketch follows the test list below.
- enables/fixes the following tests on ROCm:
    - test__int_mm_k_16_n_16_use_transpose_a_False_use_transpose_b_False_cuda
    - test__int_mm_k_16_n_16_use_transpose_a_False_use_transpose_b_True_cuda
    - test__int_mm_k_16_n_16_use_transpose_a_True_use_transpose_b_False_cuda
    - test__int_mm_k_16_n_16_use_transpose_a_True_use_transpose_b_True_cuda
    - test__int_mm_k_16_n_32_use_transpose_a_False_use_transpose_b_False_cuda
    - test__int_mm_k_16_n_32_use_transpose_a_False_use_transpose_b_True_cuda
    - test__int_mm_k_16_n_32_use_transpose_a_True_use_transpose_b_False_cuda
    - test__int_mm_k_16_n_32_use_transpose_a_True_use_transpose_b_True_cuda
    - test__int_mm_k_32_n_16_use_transpose_a_False_use_transpose_b_False_cuda
    - test__int_mm_k_32_n_16_use_transpose_a_False_use_transpose_b_True_cuda
    - test__int_mm_k_32_n_16_use_transpose_a_True_use_transpose_b_False_cuda
    - test__int_mm_k_32_n_16_use_transpose_a_True_use_transpose_b_True_cuda
    - test__int_mm_k_32_n_32_use_transpose_a_False_use_transpose_b_False_cuda
    - test__int_mm_k_32_n_32_use_transpose_a_False_use_transpose_b_True_cuda
    - test__int_mm_k_32_n_32_use_transpose_a_True_use_transpose_b_False_cuda
    - test__int_mm_k_32_n_32_use_transpose_a_True_use_transpose_b_True_cuda
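
For reference, torch._int_mm multiplies two int8 matrices and returns an int32 result; that is the op the tests above exercise. A minimal usage sketch (shapes mirror the tests above, i.e. m=17 with k and n in {16, 32}; not taken from the PR itself):

import torch

# int8 operands on the GPU (CUDA or ROCm build); the result accumulates in int32.
a = torch.randint(-128, 127, (17, 32), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (32, 16), dtype=torch.int8, device="cuda")
c = torch._int_mm(a, b)
assert c.dtype == torch.int32 and c.shape == (17, 16)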

Pull Request resolved: pytorch#122431
Approved by: https://github.com/pruthvistony, https://github.com/jithunnair-amd, https://github.com/malfet, https://github.com/atalman
petrex authored and pytorchmergebot committed Apr 24, 2024
1 parent f0f7452 commit 2e7b8ff
Showing 3 changed files with 43 additions and 8 deletions.
39 changes: 37 additions & 2 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -1655,11 +1655,34 @@ void int8_gemm(
  CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
  CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);

  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();

  // cublas team: alpha and beta need to be the same dtype as of scaleType
  at::opmath_type<int32_t> alpha_val = 1;
  int32_t beta_val = 0;
  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();

#ifdef USE_ROCM
  CuBlasLtMatmulPreference preference;
  size_t workspaceSize = _getWorkspaceSize();
  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
  auto workspace = allocator.allocate(workspaceSize);
  cublasLtMatmulHeuristicResult_t heuristicResult = {};
  int returnedResult = 0;
  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
      ltHandle,
      computeDesc.descriptor(),
      Adesc.descriptor(),
      Bdesc.descriptor(),
      Cdesc.descriptor(),
      Cdesc.descriptor(),
      preference.descriptor(),
      1,
      &heuristicResult,
      &returnedResult));
  if (returnedResult == 0) {
    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
  }
#endif

  cublasStatus_t cublasStatus = cublasLtMatmul(
      ltHandle,
@@ -1674,9 +1697,21 @@ void int8_gemm(
      Cdesc.descriptor(),
      result_ptr,
      Cdesc.descriptor(),
#ifdef USE_ROCM
      &heuristicResult.algo,
#else
      nullptr, // Heuristics don't seem to work for int8
#endif
#ifdef USE_ROCM
      workspace.mutable_get(),
#else
      nullptr, // Non-zero workspace doesn't seem to work.
#endif
#ifdef USE_ROCM
      workspaceSize,
#else
      0,
#endif
      at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(
      cublasStatus == CUBLAS_STATUS_SUCCESS,
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Blas.cpp
@@ -747,7 +747,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)

  TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");

#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070
#if (!defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
  cublasCommonArgs args(self, mat2, result);

  at::cuda::blas::int8_gemm(
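
The one-line change above widens the compile-time gate so the cuBLASLt/hipBLASLt int8 path also builds on ROCm >= 6.0 (previously only CUDA >= 11.7 on non-Windows builds). A rough runtime mirror of that gate in Python, as a hedged sketch; int_mm_likely_supported is a hypothetical helper, not an API added by this PR:

import sys
import torch

def int_mm_likely_supported() -> bool:
    # Rough runtime mirror of the compile-time guard in Blas.cpp:
    # ROCm >= 6.0, or CUDA >= 11.7 on a non-Windows build.
    if torch.version.hip is not None:  # ROCm build
        return int(torch.version.hip.split(".")[0]) >= 6
    if torch.version.cuda is not None and sys.platform != "win32":  # CUDA build
        major, minor = (int(x) for x in torch.version.cuda.split(".")[:2])
        return (major, minor) >= (11, 7)
    return False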
10 changes: 5 additions & 5 deletions test/test_linalg.py
@@ -5794,17 +5794,14 @@ def test_matmul_45724(self, device):
        self.assertEqual(c, cpu_result)

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @unittest.skipIf(SM90OrLater, "Expected failure on sm90")
    @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
    @onlyCUDA
    @parametrize("k", [16, 32])
    @parametrize("n", [16, 32])
    @parametrize("use_transpose_a", [True, False])
    @parametrize("use_transpose_b", [True, False])
    def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
        if TEST_WITH_ROCM:
            self.skipTest("_int_mm not compiled for ROCM")

        def genf_int_float(x, y, use_transpose):
            if use_transpose:
                x, y = y, x
@@ -5837,7 +5834,10 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
        SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
        SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
        SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)
        if version >= (11, 7):

        if TEST_WITH_ROCM:
            _test(17, k, n, use_transpose_a, use_transpose_b, True)
        elif version >= (11, 7):
            if not use_transpose_a and use_transpose_b:
                if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
                    _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
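
For context, a standalone sketch of the kind of check these parametrizations now run on ROCm (not the actual helpers from test_linalg.py): build int8 operands, optionally as non-contiguous transposed views, call torch._int_mm, and compare against an exact integer reference computed on the CPU.

import torch

def check_int_mm(m=17, k=32, n=16, transpose_a=False, transpose_b=False):
    def gen_int8(rows, cols, transpose):
        # When transposed, allocate the swapped shape and return a
        # non-contiguous transposed view, as the parametrized tests do.
        shape = (cols, rows) if transpose else (rows, cols)
        t = torch.randint(-128, 127, shape, dtype=torch.int8, device="cuda")
        return t.t() if transpose else t

    a = gen_int8(m, k, transpose_a)
    b = gen_int8(k, n, transpose_b)
    out = torch._int_mm(a, b)

    # Exact reference: accumulate in a wider integer type on the CPU.
    ref = torch.mm(a.cpu().to(torch.int64), b.cpu().to(torch.int64)).to(torch.int32)
    torch.testing.assert_close(out.cpu(), ref)

# On a ROCm >= 6.0 build all four transpose combinations are expected to pass;
# on CUDA, support depends on toolkit version and GPU architecture.
check_int_mm(transpose_a=True, transpose_b=True)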
