From 2e7b8ff116c10d4d3cdf201e62802675042fda5f Mon Sep 17 00:00:00 2001
From: Peter Y Yeh
Date: Wed, 24 Apr 2024 02:29:28 +0000
Subject: [PATCH] [ROCm] Fix Int_mm() Integration with hipblasLT (#122431)

The PR
- fixes int_mm()/int8_gemm() integration with the hipblasLT backend (requires ROCm 6.0).
- enables/fixes the following tests on ROCm:
  - test__int_mm_k_16_n_16_use_transpose_a_False_use_transpose_b_False_cuda
  - test__int_mm_k_16_n_16_use_transpose_a_False_use_transpose_b_True_cuda
  - test__int_mm_k_16_n_16_use_transpose_a_True_use_transpose_b_False_cuda
  - test__int_mm_k_16_n_16_use_transpose_a_True_use_transpose_b_True_cuda
  - test__int_mm_k_16_n_32_use_transpose_a_False_use_transpose_b_False_cuda
  - test__int_mm_k_16_n_32_use_transpose_a_False_use_transpose_b_True_cuda
  - test__int_mm_k_16_n_32_use_transpose_a_True_use_transpose_b_False_cuda
  - test__int_mm_k_16_n_32_use_transpose_a_True_use_transpose_b_True_cuda
  - test__int_mm_k_32_n_16_use_transpose_a_False_use_transpose_b_False_cuda
  - test__int_mm_k_32_n_16_use_transpose_a_False_use_transpose_b_True_cuda
  - test__int_mm_k_32_n_16_use_transpose_a_True_use_transpose_b_False_cuda
  - test__int_mm_k_32_n_16_use_transpose_a_True_use_transpose_b_True_cuda
  - test__int_mm_k_32_n_32_use_transpose_a_False_use_transpose_b_False_cuda
  - test__int_mm_k_32_n_32_use_transpose_a_False_use_transpose_b_True_cuda
  - test__int_mm_k_32_n_32_use_transpose_a_True_use_transpose_b_False_cuda
  - test__int_mm_k_32_n_32_use_transpose_a_True_use_transpose_b_True_cuda

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122431
Approved by: https://github.com/pruthvistony, https://github.com/jithunnair-amd, https://github.com/malfet, https://github.com/atalman
---
 aten/src/ATen/cuda/CUDABlas.cpp    | 39 ++++++++++++++++++++++++++++--
 aten/src/ATen/native/cuda/Blas.cpp |  2 +-
 test/test_linalg.py                | 10 ++++----
 3 files changed, 43 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index f9ac77b53e138..c211092c4998a 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -1655,11 +1655,34 @@ void int8_gemm(
   CuBlasLtMatrixLayout Bdesc(abType, k, n, mat2_ld, transpose_mat2);
   CuBlasLtMatrixLayout Cdesc(cType, m, n, result_ld);
 
-  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
-
   // cublas team: alpha and beta need to be the same dtype as of scaleType
   at::opmath_type<int32_t> alpha_val = 1;
   int32_t beta_val = 0;
+  cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+
+#ifdef USE_ROCM
+  CuBlasLtMatmulPreference preference;
+  size_t workspaceSize = _getWorkspaceSize();
+  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
+  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+  auto workspace = allocator.allocate(workspaceSize);
+  cublasLtMatmulHeuristicResult_t heuristicResult = {};
+  int returnedResult = 0;
+  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
+      ltHandle,
+      computeDesc.descriptor(),
+      Adesc.descriptor(),
+      Bdesc.descriptor(),
+      Cdesc.descriptor(),
+      Cdesc.descriptor(),
+      preference.descriptor(),
+      1,
+      &heuristicResult,
+      &returnedResult));
+  if (returnedResult == 0) {
+    TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED);
+  }
+#endif
 
   cublasStatus_t cublasStatus = cublasLtMatmul(
       ltHandle,
@@ -1674,9 +1697,21 @@ void int8_gemm(
       Cdesc.descriptor(),
       result_ptr,
       Cdesc.descriptor(),
+#ifdef USE_ROCM
+      &heuristicResult.algo,
+#else
       nullptr, // Heuristics don't seem to work for int8
+#endif
+#ifdef USE_ROCM
+      workspace.mutable_get(),
+#else
       nullptr, // Non-zero workspace doesn't seem to work.
+#endif
+#ifdef USE_ROCM
+      workspaceSize,
+#else
       0,
+#endif
       at::cuda::getCurrentCUDAStream());
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 7195f939f746a..060eb7408b0be 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -747,7 +747,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
 
   TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
 
-#if !defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070
+#if (!defined(USE_ROCM) && !defined(_MSC_VER) && defined(CUDA_VERSION) && CUDA_VERSION >= 11070) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
   cublasCommonArgs args(self, mat2, result);
 
   at::cuda::blas::int8_gemm(
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 5053b5e8fe167..e22dabcf56e85 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5794,7 +5794,7 @@ def test_matmul_45724(self, device):
         self.assertEqual(c, cpu_result)
 
     @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
-    @unittest.skipIf(SM90OrLater, "Expected failure on sm90")
+    @unittest.skipIf(SM90OrLater and not TEST_WITH_ROCM, "Expected failure on sm90")
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA
     @parametrize("k", [16, 32])
@@ -5802,9 +5802,6 @@ def test_matmul_45724(self, device):
     @parametrize("use_transpose_a", [True, False])
     @parametrize("use_transpose_b", [True, False])
     def test__int_mm(self, device, k, n, use_transpose_a, use_transpose_b):
-        if TEST_WITH_ROCM:
-            self.skipTest("_int_mm not compiled for ROCM")
-
         def genf_int_float(x, y, use_transpose):
             if use_transpose:
                 x, y = y, x
@@ -5837,7 +5834,10 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
             SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
             SM70 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 0)
             SM75 = torch.cuda.is_available() and torch.cuda.get_device_capability() == (7, 5)
-            if version >= (11, 7):
+
+            if TEST_WITH_ROCM:
+                _test(17, k, n, use_transpose_a, use_transpose_b, True)
+            elif version >= (11, 7):
                 if not use_transpose_a and use_transpose_b:
                     if SM80OrLater or (version >= (12, 3) and (SM70 or SM75)):
                         _test(17, k, n, use_transpose_a, use_transpose_b, version > (11, 7))
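
As a quick illustration of the path this change enables, here is a minimal sketch (not part of the patch) that exercises the int8 GEMM through torch._int_mm, mirroring what test__int_mm does. It assumes a ROCm 6.0+ or CUDA 11.7+ build with a GPU available, and shapes that satisfy the checks in _int_mm_out_cuda (m > 16, k and n positive multiples of 8):

    # Minimal sketch, mirroring test__int_mm; not part of the patch.
    import torch

    # m = 17 (> 16); k and n are multiples of 8.
    a = torch.randint(-5, 5, (17, 32), dtype=torch.int8, device="cuda")
    b = torch.randint(-5, 5, (32, 16), dtype=torch.int8, device="cuda")

    # int8 x int8 -> int32 GEMM; on ROCm >= 6.0 this now dispatches to hipblasLt.
    c = torch._int_mm(a, b)

    # Reference computed in float32 (exact for these small integer values).
    ref = torch.mm(a.float(), b.float()).to(torch.int32)
    assert torch.equal(c, ref)

The design difference between the two backends is visible in the diff above: under USE_ROCM the matmul is given an explicit algorithm from cublasLtMatmulAlgoGetHeuristic and a non-zero workspace, whereas the CUDA path keeps passing nullptr for both.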