Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Nov 6, 2023
1 parent 587dd2d commit bfb87a1
Showing 1 changed file with 79 additions and 23 deletions.
102 changes: 79 additions & 23 deletions cupy_backends/cupy_lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

#elif defined(CUPY_USE_HIP) // #if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP)

#include "hip/cupy_rocsolver.h"
#include "hip/cupy_hipsolver.h"

#else // #if !defined(CUPY_NO_CUDA) && !defined(CUPY_USE_HIP)

Expand Down Expand Up @@ -120,6 +120,53 @@ int geqrf_loop(
return status;
}

/*
* loop-based batched orgqr (used on CUDA)
*/
template<typename T>
using orgqr = cusolverStatus_t (*)(cusolverDnHandle_t, int, int, int, T*, int, const T*, T*, int, int*);

template<typename T> struct orgqr_func { orgqr<T> ptr; };
template<> struct orgqr_func<float> { orgqr<float> ptr = cusolverDnSorgqr; };
template<> struct orgqr_func<double> { orgqr<double> ptr = cusolverDnDorgqr; };
template<> struct orgqr_func<cuComplex> { orgqr<cuComplex> ptr = cusolverDnCungqr; };
template<> struct orgqr_func<cuDoubleComplex> { orgqr<cuDoubleComplex> ptr = cusolverDnZungqr; };

template<typename T>
int orgqr_loop(
intptr_t handle, int m, int n, int k, intptr_t a_ptr, int lda,
intptr_t tau_ptr, intptr_t w_ptr,
int buffersize, intptr_t info_ptr,
int batch_size, int origin_n) {
/*
* Assumptions:
* 1. the stream is set prior to calling this function
* 2. the workspace is reused in the loop
*/

cusolverStatus_t status;
T* A = reinterpret_cast<T*>(a_ptr);
const T* Tau = reinterpret_cast<const T*>(tau_ptr);
T* Work = reinterpret_cast<T*>(w_ptr);
int* devInfo = reinterpret_cast<int*>(info_ptr);

// we can't use "if constexpr" to do a compile-time branch selection as it's C++17 only,
// so we use custom traits instead
orgqr<T> func = orgqr_func<T>().ptr;

for (int i=0; i<batch_size; i++) {
status = func(reinterpret_cast<cusolverDnHandle_t>(handle),
m, n, k, A, lda, Tau, Work, buffersize, devInfo);
if (status != 0) break;
A += m * origin_n;
Tau += k;
devInfo += 1;
}

return status;
}


#else

template<typename T>
Expand All @@ -137,14 +184,14 @@ int gesvd_loop(
* batched geqrf (only used on HIP)
*/
template<typename T>
using geqrf = cusolverStatus_t (*)(cusolverDnHandle_t, int, int, T* const[], int, T*, long int, int);
using geqrf = hipsolverStatus_t (*)(hipsolverDnHandle_t, int, int, T*, int, T*, T*, int, int*);

template<typename T> struct geqrf_func { geqrf<T> ptr; };
template<> struct geqrf_func<float> { geqrf<float> ptr = rocsolver_sgeqrf_batched; };
template<> struct geqrf_func<double> { geqrf<double> ptr = rocsolver_dgeqrf_batched; };
template<> struct geqrf_func<float> { geqrf<float> ptr = hipsolverSgeqrf; };
template<> struct geqrf_func<double> { geqrf<double> ptr = hipsolverDgeqrf; };
// we need the correct func pointer here, so can't cast!
template<> struct geqrf_func<rocblas_float_complex> { geqrf<rocblas_float_complex> ptr = rocsolver_cgeqrf_batched; };
template<> struct geqrf_func<rocblas_double_complex> { geqrf<rocblas_double_complex> ptr = rocsolver_zgeqrf_batched; };
template<> struct geqrf_func<hipFloatComplex> { geqrf<hipFloatComplex> ptr = hipsolverCgeqrf; };
template<> struct geqrf_func<hipDoubleComplex> { geqrf<hipDoubleComplex> ptr = hipsolverZgeqrf; };

template<typename T>
int geqrf_loop(
Expand All @@ -158,41 +205,47 @@ int geqrf_loop(
* 2. ignore w_ptr, buffersize, and info_ptr as rocSOLVER does not need them
*/

cusolverStatus_t status;
hipsolverStatus_t status;

// we can't use "if constexpr" to do a compile-time branch selection as it's C++17 only,
// so we use custom traits instead
typedef typename std::conditional<
std::is_floating_point<T>::value,
T,
typename std::conditional<std::is_same<T, cuComplex>::value,
rocblas_float_complex,
rocblas_double_complex>::type
typename std::conditional<std::is_same<T, hipFloatComplex>::value,
hipFloatComplex,
hipDoubleComplex>::type
>::type data_type;
geqrf<data_type> func = geqrf_func<data_type>().ptr;
data_type* const* A = reinterpret_cast<data_type* const*>(a_ptr);
data_type* A = reinterpret_cast<data_type*>(a_ptr);
data_type* Tau = reinterpret_cast<data_type*>(tau_ptr);
int k = (m<n)?m:n;

// use rocSOLVER's batched geqrf
status = func((cusolverDnHandle_t)handle, m, n, A, lda, Tau, k, batch_size);
data_type* Work = reinterpret_cast<data_type*>(w_ptr);
int* devInfo = reinterpret_cast<int*>(info_ptr);
for (int i=0; i < batch_size; i++) {
status = func(reinterpret_cast<hipsolverDnHandle_t>(handle),
m, n, A, lda, Tau, Work, buffersize, devInfo);
if (status != 0) break;
A += m * n;
Tau += k;
devInfo += 1;
}

return status;
}
#endif // #if !defined(CUPY_USE_HIP)


/*
* loop-based batched orgqr (used on both CUDA & HIP)
* loop-based batched orgqr (used on HIP)
*/
template<typename T>
using orgqr = cusolverStatus_t (*)(cusolverDnHandle_t, int, int, int, T*, int, const T*, T*, int, int*);
using orgqr = hipsolverStatus_t (*)(hipsolverDnHandle_t, int, int, int, T*, int, const T*, T*, int, int*);

template<typename T> struct orgqr_func { orgqr<T> ptr; };
template<> struct orgqr_func<float> { orgqr<float> ptr = cusolverDnSorgqr; };
template<> struct orgqr_func<double> { orgqr<double> ptr = cusolverDnDorgqr; };
template<> struct orgqr_func<cuComplex> { orgqr<cuComplex> ptr = cusolverDnCungqr; };
template<> struct orgqr_func<cuDoubleComplex> { orgqr<cuDoubleComplex> ptr = cusolverDnZungqr; };
template<> struct orgqr_func<float> { orgqr<float> ptr = hipsolverDnSorgqr; };
template<> struct orgqr_func<double> { orgqr<double> ptr = hipsolverDnDorgqr; };
template<> struct orgqr_func<hipFloatComplex> { orgqr<hipFloatComplex> ptr = hipsolverDnCungqr; };
template<> struct orgqr_func<hipDoubleComplex> { orgqr<hipDoubleComplex> ptr = hipsolverDnZungqr; };

template<typename T>
int orgqr_loop(
Expand All @@ -206,7 +259,7 @@ int orgqr_loop(
* 2. the workspace is reused in the loop
*/

cusolverStatus_t status;
hipsolverStatus_t status;
T* A = reinterpret_cast<T*>(a_ptr);
const T* Tau = reinterpret_cast<const T*>(tau_ptr);
T* Work = reinterpret_cast<T*>(w_ptr);
Expand All @@ -217,7 +270,7 @@ int orgqr_loop(
orgqr<T> func = orgqr_func<T>().ptr;

for (int i=0; i<batch_size; i++) {
status = func(reinterpret_cast<cusolverDnHandle_t>(handle),
status = func(reinterpret_cast<hipsolverDnHandle_t>(handle),
m, n, k, A, lda, Tau, Work, buffersize, devInfo);
if (status != 0) break;
A += m * origin_n;
Expand All @@ -227,4 +280,7 @@ int orgqr_loop(

return status;
}

#endif // #if !defined(CUPY_USE_HIP)

#endif // #ifndef INCLUDE_GUARD_CUPY_CUSOLVER_H

0 comments on commit bfb87a1

Please sign in to comment.