forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Migrate addmm, addbmm and THBlas_gemm to ATen (pytorch#40927)
Summary: Closes pytorch#24679; closes pytorch#24678. `addbmm` depends on `addmm`, so the two needed to be ported at the same time. I also removed `THTensor_(baddbmm)`, which I noticed had already been ported, so it was just dead code. After having already written this code, I had to fix merge conflicts with pytorch#40354, which revealed there was already an established place for CPU BLAS routines in ATen. However, the version there doesn't make use of ATen's AVX dispatching, so I thought I'd wait for comment before migrating this into that style. Pull Request resolved: pytorch#40927 Differential Revision: D22418756 Pulled By: ezyang fbshipit-source-id: 44e7bb5964263d73ae8cc6adc5f6d4e966476ae6
- Loading branch information
1 parent
3f32332
commit 6725c03
Showing
19 changed files
with
775 additions
and
660 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
#include <ATen/native/CPUBlas.h> | ||
#include <ATen/Config.h> | ||
|
||
#include <climits> | ||
|
||
#if AT_BUILD_WITH_BLAS() | ||
extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc); | ||
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc); | ||
extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); | ||
extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); | ||
#endif // AT_BUILD_WITH_BLAS() | ||
|
||
#ifdef USE_FBGEMM | ||
#include <fbgemm/FbgemmI64.h> | ||
#endif // USE_FBGEMM | ||
|
||
namespace at { | ||
namespace native { | ||
namespace cpublas { | ||
namespace internal { | ||
|
||
// When a matrix dimension is 1, its leading dimension is unconstrained by
// the data layout, so normalize it to the value a contiguous matrix would
// have. This lets callers pass degenerate strides without tripping the
// BLAS leading-dimension requirements checked later.
void normalize_last_dims(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  if (n == 1) {
    *ldc = m;
  }

  const bool a_transposed = (transa != NoTranspose);
  // For A: the "inner" extent is m when transposed, k otherwise.
  if (a_transposed ? (m == 1) : (k == 1)) {
    *lda = a_transposed ? k : m;
  }

  const bool b_transposed = (transb != NoTranspose);
  // For B: the "inner" extent is k when transposed, n otherwise.
  if (b_transposed ? (k == 1) : (n == 1)) {
    *ldb = b_transposed ? n : k;
  }
}
} // namespace internal | ||
|
||
namespace { | ||
|
||
// Returns true when the FORTRAN BLAS entry points can be used: every
// dimension must fit in a 32-bit int (BLAS takes int arguments) and each
// leading dimension must meet the minimum required by the BLAS interface.
bool use_blas_gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    int64_t &lda, int64_t &ldb, int64_t &ldc) {
  const int64_t min_lda = std::max(int64_t{1}, (transa != NoTranspose) ? k : m);
  const int64_t min_ldb = std::max(int64_t{1}, (transb != NoTranspose) ? n : k);
  const int64_t min_ldc = std::max(int64_t{1}, m);
  const bool fits_in_int =
      (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
      (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX);
  return fits_in_int &&
      (lda >= min_lda) && (ldb >= min_ldb) && (ldc >= min_ldc);
}
|
||
#if AT_BUILD_WITH_BLAS() | ||
// Maps the ATen transpose flag to the single-character code expected by the
// FORTRAN BLAS interface ('n' = no transpose, 't' = transpose).
char to_blas(TransposeType trans) {
  if (trans == NoTranspose) {
    return 'n';
  }
  if (trans == Transpose) {
    return 't';
  }
  // ConjTranspose ('c') is intentionally not handled here.
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#endif // AT_BUILD_WITH_BLAS | ||
|
||
#ifdef USE_FBGEMM | ||
// Maps the ATen transpose flag to FBGEMM's matrix_op_t enumeration.
fbgemm::matrix_op_t to_fbgemm(TransposeType trans) {
  if (trans == NoTranspose) {
    return fbgemm::matrix_op_t::NoTranspose;
  }
  if (trans == Transpose) {
    return fbgemm::matrix_op_t::Transpose;
  }
  // ConjTranspose is intentionally not handled here.
  TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
}
#endif // USE_FBGEMM | ||
|
||
} // namespace (anonymous) | ||
|
||
// Dispatch table for the fallback (non-BLAS/non-FBGEMM) gemm kernel invoked
// by the gemm() overloads below when the fast paths are unavailable.
DEFINE_DISPATCH(gemm_stub);
|
||
// Double-precision gemm: C = alpha * op(A) * op(B) + beta * C.
// Uses BLAS dgemm when the problem fits its 32-bit interface, otherwise
// falls back to the ATen dispatch-stub kernel.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const double alpha,
    const double *a, int64_t lda,
    const double *b, int64_t ldb,
    const double beta,
    double *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    // BLAS takes every argument by pointer, hence the narrowed local copies.
    char transa_ = to_blas(transa);
    char transb_ = to_blas(transb);
    int m_ = m;
    int n_ = n;
    int k_ = k;
    int lda_ = lda;
    int ldb_ = ldb;
    int ldc_ = ldc;
    double alpha_ = alpha;
    double beta_ = beta;
    dgemm_(
        &transa_, &transb_, &m_, &n_, &k_,
        &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kDouble,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
|
||
// Single-precision gemm: C = alpha * op(A) * op(B) + beta * C.
// Uses BLAS sgemm when the problem fits its 32-bit interface, otherwise
// falls back to the ATen dispatch-stub kernel.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const float alpha,
    const float *a, int64_t lda,
    const float *b, int64_t ldb,
    const float beta,
    float *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    // BLAS takes every argument by pointer, hence the narrowed local copies.
    char transa_ = to_blas(transa);
    char transb_ = to_blas(transb);
    int m_ = m;
    int n_ = n;
    int k_ = k;
    int lda_ = lda;
    int ldb_ = ldb;
    int ldc_ = ldc;
    float alpha_ = alpha;
    float beta_ = beta;
    sgemm_(
        &transa_, &transb_, &m_, &n_, &k_,
        &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
|
||
// Complex double-precision gemm: C = alpha * op(A) * op(B) + beta * C.
// Uses BLAS zgemm (declared with void* operands) when the problem fits the
// 32-bit interface, otherwise falls back to the ATen dispatch-stub kernel.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const c10::complex<double> alpha,
    const c10::complex<double> *a, int64_t lda,
    const c10::complex<double> *b, int64_t ldb,
    const c10::complex<double> beta,
    c10::complex<double> *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    // BLAS takes every argument by pointer, hence the narrowed local copies.
    char transa_ = to_blas(transa);
    char transb_ = to_blas(transb);
    int m_ = m;
    int n_ = n;
    int k_ = k;
    int lda_ = lda;
    int ldb_ = ldb;
    int ldc_ = ldc;
    c10::complex<double> alpha_ = alpha;
    c10::complex<double> beta_ = beta;
    zgemm_(
        &transa_, &transb_, &m_, &n_, &k_,
        &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kComplexDouble,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
|
||
// Complex single-precision gemm: C = alpha * op(A) * op(B) + beta * C.
// Uses BLAS cgemm (declared with void* operands) when the problem fits the
// 32-bit interface, otherwise falls back to the ATen dispatch-stub kernel.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const c10::complex<float> alpha,
    const c10::complex<float> *a, int64_t lda,
    const c10::complex<float> *b, int64_t ldb,
    const c10::complex<float> beta,
    c10::complex<float> *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if AT_BUILD_WITH_BLAS()
  if (use_blas_gemm(transa, transb, m, n, k, lda, ldb, ldc)) {
    // BLAS takes every argument by pointer, hence the narrowed local copies.
    char transa_ = to_blas(transa);
    char transb_ = to_blas(transb);
    int m_ = m;
    int n_ = n;
    int k_ = k;
    int lda_ = lda;
    int ldb_ = ldb;
    int ldc_ = ldc;
    c10::complex<float> alpha_ = alpha;
    c10::complex<float> beta_ = beta;
    cgemm_(
        &transa_, &transb_, &m_, &n_, &k_,
        &alpha_, a, &lda_, b, &ldb_, &beta_, c, &ldc_);
    return;
  }
#endif
  gemm_stub(
      at::kCPU, at::kComplexFloat,
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
|
||
// 64-bit integer gemm: C = alpha * op(A) * op(B) + beta * C.
// Prefers FBGEMM's i64 kernel when its constraints are met; otherwise
// falls back to the ATen dispatch-stub kernel.
void gemm(
    TransposeType transa, TransposeType transb,
    int64_t m, int64_t n, int64_t k,
    const int64_t alpha,
    const int64_t *a, int64_t lda,
    const int64_t *b, int64_t ldb,
    const int64_t beta,
    int64_t *c, int64_t ldc) {
  internal::normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#ifdef USE_FBGEMM
  // FBGEMM's i64 kernel only supports alpha == 1 and beta in {0, 1}
  // (accumulation is expressed as the boolean last argument below).
  if (alpha == 1 && (beta == 0 || beta == 1)) {
    // In FBGEMM, we assume row-major ordering; However, here we assume the
    // column-major ordering following the FORTRAN tradition in BLAS interface
    // in this function: we can configure the layout (row/column-major ordering)
    // of A and B by changing transa_ and transb_, but we cannot change the
    // layout of C with this FORTRAN-style BLAS interface.
    //
    // The workaround is that we compute
    //   C^T (n x m) = B^T (n x k) * A^T (k x m) instead.
    //
    // In this way we view C^T as the row-major ordering when passing to FBGEMM.
    fbgemm::cblas_gemm_i64_i64acc(
        to_fbgemm(transb),
        to_fbgemm(transa),
        n,
        m,
        k,
        b,
        ldb,
        a,
        lda,
        beta == 1,
        c,
        ldc);
    return;
  }
#endif

  gemm_stub(
      at::kCPU, at::kLong,  // qualified for consistency with the overloads above
      transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
|
||
}}} // namespace at::native::cpublas |
Oops, something went wrong.