Skip to content

Commit

Permalink
Remove _OVERLOAD layer
Browse files Browse the repository at this point in the history
  • Loading branch information
samhatfield committed Feb 6, 2025
1 parent 9d5d769 commit dfaa28f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 255 deletions.
233 changes: 0 additions & 233 deletions src/trans/gpu/algor/hicblas_mod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,8 @@
! nor does it submit to any jurisdiction.
!

#if defined CUDAGPU
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
#define OPENACC_LIB OPENACC
#endif

MODULE HICBLAS_MOD

USE EC_PARKIND, ONLY: JPIM, JPRM, JPRD, JPIB
USE GROWING_ALLOCATOR_MOD, ONLY: GROWING_ALLOCATION_TYPE
#ifdef ACCGPU
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
#endif
#ifdef OMPGPU
#endif

IMPLICIT NONE

INTERFACE
Expand Down Expand Up @@ -118,224 +105,4 @@ SUBROUTINE HIP_SGEMM_GROUPED( &
END SUBROUTINE HIP_SGEMM_GROUPED
END INTERFACE

CONTAINS

! Double-precision batched GEMM wrapper that accepts an integer stream id and
! a Fortran pointer to the allocator.
!
! The wrapper does two conversions and otherwise forwards its arguments
! unchanged to HIP_DGEMM_BATCHED:
!   * STREAM is turned into the C_LONG stream value HIP_DGEMM_BATCHED takes.
!     In OpenACC builds (ACCGPU) this goes through ACC_GET_HIP_STREAM; in
!     every other build the integer id itself is converted.
!   * ALLOC is passed as C_LOC(ALLOC).
!
! Arguments:
!   TRANSA, TRANSB             single characters, passed by value
!   M, N, K                    JPIM integers, forwarded as-is
!   ALPHA, BETA                JPRD scalars, forwarded as-is
!   AARRAY, BARRAY, CARRAY     JPRD arrays (rank 1, rank 2, rank 1)
!   LDA, LDB, LDC              JPIM integers, forwarded as-is
!   STRIDEA, STRIDEB, STRIDEC  JPIM integers, forwarded as-is
!   BATCHCOUNT                 JPIM integer, forwarded as-is
!   STREAM                     integer stream/queue id (C_INT)
!   ALLOC                      pointer to the GROWING_ALLOCATION_TYPE object
!                              whose address is handed to the callee
!
! NOTE(review): the argument names follow the usual BLAS GEMM vocabulary; see
! the HIP_DGEMM_BATCHED interface for their authoritative meaning.
SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD( &
& TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, STRIDEA, &
& BARRAY, LDB, STRIDEB, &
& BETA, &
& CARRAY, LDC, STRIDEC, &
& BATCHCOUNT, STREAM, ALLOC)
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_LONG, C_LOC
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
INTEGER(KIND=JPIM) :: M
INTEGER(KIND=JPIM) :: N
INTEGER(KIND=JPIM) :: K
REAL(KIND=JPRD) :: ALPHA
REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
INTEGER(KIND=JPIM) :: LDA
INTEGER(KIND=JPIM) :: STRIDEA
REAL(KIND=JPRD), DIMENSION(:,:) :: BARRAY
INTEGER(KIND=JPIM) :: LDB
INTEGER(KIND=JPIM) :: STRIDEB
REAL(KIND=JPRD) :: BETA
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
INTEGER(KIND=JPIM) :: LDC
INTEGER(KIND=JPIM) :: STRIDEC
INTEGER(KIND=JPIM) :: BATCHCOUNT
INTEGER(KIND=C_INT) :: STREAM
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

! Stream value in the integer kind expected by HIP_DGEMM_BATCHED
INTEGER(KIND=C_LONG) :: HIP_STREAM

#ifdef ACCGPU
! OpenACC build: look up the stream associated with the async queue id
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
#else
! No OpenACC runtime to query. HIP_STREAM used to be left undefined on this
! path, so an undefined value was handed to the callee. Pass the id itself.
! NOTE(review): assumes the callee accepts the plain stream id in non-OpenACC
! builds - confirm against the C wrapper.
HIP_STREAM = INT(STREAM, C_LONG)
#endif

! With the Cray compiler under OpenACC, the call is wrapped in a HOST_DATA
! region naming the three arrays in USE_DEVICE
#if defined(_CRAYFTN)
#ifdef ACCGPU
!$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
#endif
CALL HIP_DGEMM_BATCHED( &
& TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, STRIDEA, &
& BARRAY, LDB, STRIDEB, &
& BETA, &
& CARRAY, LDC, STRIDEC, &
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
#ifdef ACCGPU
!$ACC END HOST_DATA
#endif
#endif
END SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD

! Single-precision batched GEMM wrapper that accepts an integer stream id and
! a Fortran pointer to the allocator.
!
! The wrapper does two conversions and otherwise forwards its arguments
! unchanged to HIP_SGEMM_BATCHED:
!   * STREAM is turned into the C_LONG stream value HIP_SGEMM_BATCHED takes.
!     In OpenACC builds (ACCGPU) this goes through ACC_GET_HIP_STREAM; in
!     every other build the integer id itself is converted.
!   * ALLOC is passed as C_LOC(ALLOC).
!
! Arguments:
!   TRANSA, TRANSB             single characters, passed by value
!   M, N, K                    JPIM integers, forwarded as-is
!   ALPHA, BETA                JPRM scalars, forwarded as-is
!   AARRAY, CARRAY             JPRM rank-1 assumed-shape arrays
!   BARRAY                     JPRM assumed-size array
!   LDA, LDB, LDC              JPIM integers, forwarded as-is
!   STRIDEA, STRIDEB, STRIDEC  JPIM integers, forwarded as-is
!   BATCHCOUNT                 JPIM integer, forwarded as-is
!   STREAM                     integer stream/queue id (C_INT)
!   ALLOC                      pointer to the GROWING_ALLOCATION_TYPE object
!                              whose address is handed to the callee
!
! NOTE(review): the argument names follow the usual BLAS GEMM vocabulary; see
! the HIP_SGEMM_BATCHED interface for their authoritative meaning.
SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD( &
& TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, STRIDEA, &
& BARRAY, LDB, STRIDEB, &
& BETA, &
& CARRAY, LDC, STRIDEC, &
& BATCHCOUNT, STREAM, ALLOC)
USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_LONG, C_LOC
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
INTEGER(KIND=JPIM) :: M
INTEGER(KIND=JPIM) :: N
INTEGER(KIND=JPIM) :: K
REAL(KIND=JPRM) :: ALPHA
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
INTEGER(KIND=JPIM) :: LDA
INTEGER(KIND=JPIM) :: STRIDEA
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
INTEGER(KIND=JPIM) :: LDB
INTEGER(KIND=JPIM) :: STRIDEB
REAL(KIND=JPRM) :: BETA
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
INTEGER(KIND=JPIM) :: LDC
INTEGER(KIND=JPIM) :: STRIDEC
INTEGER(KIND=JPIM) :: BATCHCOUNT
INTEGER(KIND=C_INT) :: STREAM
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

! Stream value in the integer kind expected by HIP_SGEMM_BATCHED
INTEGER(KIND=C_LONG) :: HIP_STREAM

#ifdef ACCGPU
! OpenACC build: look up the stream associated with the async queue id
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
#else
! No OpenACC runtime to query. HIP_STREAM used to be left undefined on this
! path, so an undefined value was handed to the callee. Pass the id itself.
! NOTE(review): assumes the callee accepts the plain stream id in non-OpenACC
! builds - confirm against the C wrapper.
HIP_STREAM = INT(STREAM, C_LONG)
#endif

CALL HIP_SGEMM_BATCHED( &
& TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, STRIDEA, &
& BARRAY, LDB, STRIDEB, &
& BETA, &
& CARRAY, LDC, STRIDEC, &
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
END SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD

! Double-precision grouped GEMM wrapper that accepts an integer stream id and
! a Fortran pointer to the allocator.
!
! The wrapper does two conversions and otherwise forwards its arguments
! unchanged to HIP_DGEMM_GROUPED:
!   * STREAM is turned into the C_LONG stream value HIP_DGEMM_GROUPED takes.
!     In OpenACC builds (ACCGPU) this goes through ACC_GET_HIP_STREAM; in
!     every other build the integer id itself is converted.
!   * ALLOC is passed as C_LOC(ALLOC).
!
! Arguments:
!   RESOL_ID, BLAS_ID          C_INT identifiers (input), forwarded as-is
!   TRANSA, TRANSB             single characters, passed by value
!   M                          JPIM scalar
!   N(:), K(:)                 JPIM rank-1 arrays
!   ALPHA, BETA                JPRD scalars, forwarded as-is
!   AARRAY, CARRAY             JPRD rank-1 assumed-shape arrays
!   BARRAY                     JPRD assumed-size array
!   LDA, LDC                   JPIM scalars
!   LDB(:)                     JPIM rank-1 array
!   OFFSETA(:), OFFSETB(:), OFFSETC(:)  JPIB rank-1 arrays
!   BATCHCOUNT                 JPIM integer, forwarded as-is
!   STREAM                     integer stream/queue id (C_INT)
!   ALLOC                      pointer to the GROWING_ALLOCATION_TYPE object
!                              whose address is handed to the callee
!
! NOTE(review): presumably N, K, LDB and the OFFSET arrays carry one entry per
! group - confirm against the HIP_DGEMM_GROUPED interface.
SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, OFFSETA, &
& BARRAY, LDB, OFFSETB, &
& BETA, &
& CARRAY, LDC, OFFSETC, &
& BATCHCOUNT, STREAM, ALLOC)
USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_LONG, C_LOC
INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
INTEGER(KIND=JPIM) :: M
INTEGER(KIND=JPIM) :: N(:)
INTEGER(KIND=JPIM) :: K(:)
REAL(KIND=JPRD) :: ALPHA
REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
INTEGER(KIND=JPIM) :: LDA
INTEGER(KIND=JPIB) :: OFFSETA(:)
REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
INTEGER(KIND=JPIM) :: LDB(:)
INTEGER(KIND=JPIB) :: OFFSETB(:)
REAL(KIND=JPRD) :: BETA
REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
INTEGER(KIND=JPIM) :: LDC
INTEGER(KIND=JPIB) :: OFFSETC(:)
INTEGER(KIND=JPIM) :: BATCHCOUNT
INTEGER(KIND=C_INT) :: STREAM
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

! Stream value in the integer kind expected by HIP_DGEMM_GROUPED
INTEGER(KIND=C_LONG) :: HIP_STREAM

#ifdef ACCGPU
! OpenACC build: look up the stream associated with the async queue id
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
#else
! No OpenACC runtime to query. HIP_STREAM used to be left undefined on this
! path, so an undefined value was handed to the callee. Pass the id itself.
! NOTE(review): assumes the callee accepts the plain stream id in non-OpenACC
! builds - confirm against the C wrapper.
HIP_STREAM = INT(STREAM, C_LONG)
#endif

CALL HIP_DGEMM_GROUPED( &
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, OFFSETA, &
& BARRAY, LDB, OFFSETB, &
& BETA, &
& CARRAY, LDC, OFFSETC, &
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))

END SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD

! Single-precision grouped GEMM wrapper that accepts an integer stream id and
! a Fortran pointer to the allocator.
!
! The wrapper does two conversions and otherwise forwards its arguments
! unchanged to HIP_SGEMM_GROUPED:
!   * STREAM is turned into the C_LONG stream value HIP_SGEMM_GROUPED takes.
!     In OpenACC builds (ACCGPU) this goes through ACC_GET_HIP_STREAM; in
!     every other build the integer id itself is converted.
!   * ALLOC is passed as C_LOC(ALLOC).
!
! Arguments:
!   RESOL_ID, BLAS_ID          C_INT identifiers (input), forwarded as-is
!   TRANSA, TRANSB             single characters, passed by value
!   M                          JPIM scalar
!   N(:), K(:)                 JPIM rank-1 arrays
!   ALPHA, BETA                JPRM scalars, forwarded as-is
!   AARRAY, CARRAY             JPRM rank-1 assumed-shape arrays
!   BARRAY                     JPRM assumed-size array
!   LDA, LDC                   JPIM scalars
!   LDB(:)                     JPIM rank-1 array
!   OFFSETA(:), OFFSETB(:), OFFSETC(:)  JPIB rank-1 arrays
!   BATCHCOUNT                 JPIM integer, forwarded as-is
!   STREAM                     integer stream/queue id (C_INT)
!   ALLOC                      pointer to the GROWING_ALLOCATION_TYPE object
!                              whose address is handed to the callee
!
! NOTE(review): presumably N, K, LDB and the OFFSET arrays carry one entry per
! group - confirm against the HIP_SGEMM_GROUPED interface.
SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD(&
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, OFFSETA, &
& BARRAY, LDB, OFFSETB, &
& BETA, &
& CARRAY, LDC, OFFSETC, &
& BATCHCOUNT, STREAM, ALLOC)
USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_LONG, C_LOC
INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
INTEGER(KIND=JPIM) :: M
INTEGER(KIND=JPIM) :: N(:)
INTEGER(KIND=JPIM) :: K(:)
REAL(KIND=JPRM) :: ALPHA
REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
INTEGER(KIND=JPIM) :: LDA
INTEGER(KIND=JPIB) :: OFFSETA(:)
REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
INTEGER(KIND=JPIM) :: LDB(:)
INTEGER(KIND=JPIB) :: OFFSETB(:)
REAL(KIND=JPRM) :: BETA
REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
INTEGER(KIND=JPIM) :: LDC
INTEGER(KIND=JPIB) :: OFFSETC(:)
INTEGER(KIND=JPIM) :: BATCHCOUNT
INTEGER(KIND=C_INT) :: STREAM
TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

! Stream value in the integer kind expected by HIP_SGEMM_GROUPED
INTEGER(KIND=C_LONG) :: HIP_STREAM

#ifdef ACCGPU
! OpenACC build: look up the stream associated with the async queue id
HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_LONG)
#else
! No OpenACC runtime to query. HIP_STREAM used to be left undefined on this
! path, so an undefined value was handed to the callee. Pass the id itself.
! NOTE(review): assumes the callee accepts the plain stream id in non-OpenACC
! builds - confirm against the C wrapper.
HIP_STREAM = INT(STREAM, C_LONG)
#endif

! With the Cray compiler under OpenACC, the call is wrapped in a HOST_DATA
! region naming the three arrays in USE_DEVICE
#if defined(_CRAYFTN)
#ifdef ACCGPU
!$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
#endif
CALL HIP_SGEMM_GROUPED( &
& RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
& M, N, K, &
& ALPHA, &
& AARRAY, LDA, OFFSETA, &
& BARRAY, LDB, OFFSETB, &
& BETA, &
& CARRAY, LDC, OFFSETC, &
& BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
#ifdef ACCGPU
!$ACC END HOST_DATA
#endif
#endif

END SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD

END MODULE HICBLAS_MOD
34 changes: 23 additions & 11 deletions src/trans/gpu/internal/ledir_mod.F90
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,19 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
USE TPM_GEOMETRY, ONLY: G
USE TPM_FIELDS_GPU, ONLY: FG
USE TPM_DISTR, ONLY: D
USE HICBLAS_MOD, ONLY: HIP_DGEMM_BATCHED_OVERLOAD, &
& HIP_DGEMM_GROUPED_OVERLOAD, HIP_SGEMM_GROUPED_OVERLOAD
USE HICBLAS_MOD, ONLY: HIP_DGEMM_BATCHED, &
& HIP_DGEMM_GROUPED, HIP_SGEMM_GROUPED
USE MPL_MODULE, ONLY: MPL_BARRIER,MPL_ALL_MS_COMM
USE TPM_STATS, ONLY: GSTATS => GSTATS_NVTX
USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_INT
USE, INTRINSIC :: ISO_C_BINDING, ONLY: C_INT, C_LONG, C_LOC
#ifdef ACCGPU
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
#endif

#ifdef TRANS_SINGLE
#define HIP_GEMM HIP_SGEMM_GROUPED_OVERLOAD
#define HIP_GEMM HIP_SGEMM_GROUPED
#else
#define HIP_GEMM HIP_DGEMM_GROUPED_OVERLOAD
#define HIP_GEMM HIP_DGEMM_GROUPED
#endif

IMPLICIT NONE
Expand Down Expand Up @@ -149,12 +152,21 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
INTEGER(KIND=JPIM) :: IIN0_STRIDES0, IIN0_STRIDES1
INTEGER(KIND=8) :: ALLOC_SZ, ALLOC_POS

INTEGER(KIND=C_LONG) :: HIP_STREAM

ASSOCIATE(D_NUMP=>D%NUMP, R_NSMAX=>R%NSMAX, R_NTMAX=>R%NTMAX, G_NDGLU=>G%NDGLU, &
& D_MYMS=>D%MYMS, D_OFFSETS_GEMM1=>D%OFFSETS_GEMM1, &
& D_OFFSETS_GEMM2=>D%OFFSETS_GEMM2, &
& ZAA=>FG%ZAA, ZAS=>FG%ZAS, ZAA0=>FG%ZAA0, ZAS0=>FG%ZAS0)
IF (LHOOK) CALL DR_HOOK('LE_DGEMM',0,ZHOOK_HANDLE)

#ifdef ACCGPU
HIP_STREAM = INT(ACC_GET_HIP_STREAM(1_C_INT), C_LONG)
#endif
#ifdef OMPGPU
HIP_STREAM = 1_C_LONG
#endif

CALL LEDIR_STRIDES(KF_FS,IOUT_STRIDES0,IOUT_STRIDES1,IIN_STRIDES0,IIN_STRIDES1,&
IOUT0_STRIDES0,IOUT0_STRIDES1,IIN0_STRIDES0,IIN0_STRIDES1)

Expand Down Expand Up @@ -187,15 +199,15 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
#ifdef ACCGPU
!$ACC HOST_DATA USE_DEVICE(ZAA0,ZINPA0,ZOUT0)
#endif
CALL HIP_DGEMM_BATCHED_OVERLOAD( &
CALL HIP_DGEMM_BATCHED( &
& 'N', 'N', &
& KF_FS, (R_NSMAX+2)/2, G_NDGLU(0), &
& 1.0_JPRD, &
& ZINPA0, IIN0_STRIDES0, 0, &
& ZAA0, SIZE(ZAA0,1), 0, &
& 0.0_JPRD, &
& ZOUT0, IOUT0_STRIDES0, 0, &
& 1, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
& 1, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
#ifdef OMPGPU
!$OMP END TARGET DATA
#endif
Expand Down Expand Up @@ -233,7 +245,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
& ZAA, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
& 0.0_JPRBT, &
& ZOUT, IOUT_STRIDES0, COFFSETS, &
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
& D_NUMP, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
#ifdef OMPGPU
!$OMP END TARGET DATA
#endif
Expand Down Expand Up @@ -306,15 +318,15 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
!$ACC HOST_DATA USE_DEVICE(ZAS0,ZINPS0,ZOUT0)
#endif
! compute m=0 in double precision:
call HIP_DGEMM_BATCHED_OVERLOAD( &
call HIP_DGEMM_BATCHED( &
& 'N', 'N', &
& KF_FS, (R_NSMAX+3)/2, G_NDGLU(0), &
& 1.0_JPRD, &
& ZINPS0, IIN0_STRIDES0, 0, &
& ZAS0, SIZE(ZAS0,1), 0, &
& 0.0_JPRD, &
& ZOUT0, IOUT0_STRIDES0, 0, &
& 1, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
& 1, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
#ifdef OMPGPU
!$OMP END TARGET DATA
#endif
Expand Down Expand Up @@ -353,7 +365,7 @@ SUBROUTINE LEDIR(ALLOCATOR,ZINPS,ZINPA,ZINPS0,ZINPA0,ZOUT,ZOUT0,POA1,KF_FS)
& ZAS, D%LEGENDRE_MATRIX_STRIDES, BOFFSETS, &
& 0.0_JPRBT, &
& ZOUT, IOUT_STRIDES0, COFFSETS, &
& D_NUMP, STREAM=1_C_INT, ALLOC=ALLOCATOR%PTR)
& D_NUMP, HIP_STREAM, C_LOC(ALLOCATOR%PTR))
#ifdef OMPGPU
!$OMP END TARGET DATA
#endif
Expand Down
Loading

0 comments on commit dfaa28f

Please sign in to comment.