diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 6b9d3c67d3a7d8..6ce7cda1e55761 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -32,6 +32,7 @@ "add_n", "addmm", "any", + "baddbmm", "bce_loss", "bmm", "diag", diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 3f2c8397a61415..5c5a36ae186c32 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -163,6 +163,67 @@ bool Addmm_OpInferSymbolicShape(pir::Operation *op, return AddmmOpInferSymbolicShape(op, infer_context); } +bool BaddbmmOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &y_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + + auto ndim_input = input_shape.shape().size(); + auto ndim_x = x_shape.shape().size(); + auto ndim_y = y_shape.shape().size(); + + PADDLE_ENFORCE_EQ(ndim_input, + 3, + common::errors::InvalidArgument( + "The input tensor input's dimension must be 3. " + "But received input's dimension = [%d].", + ndim_input)); + PADDLE_ENFORCE_EQ(ndim_x, + 3, + common::errors::InvalidArgument( + "The input tensor x's dimension must be 3. " + "But received x's dimension = [%d].", + ndim_x)); + PADDLE_ENFORCE_EQ(ndim_y, + 3, + common::errors::InvalidArgument( + "The input tensor y's dimension must be 3. 
" + "But received y's dimension = [%d].", + ndim_y)); + + std::vector output_shape; + output_shape.push_back(x_shape.shape()[0]); // batch size + output_shape.push_back(x_shape.shape()[1]); + output_shape.push_back(y_shape.shape()[2]); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(output_shape)}); + + infer_context->AddEqualCstr(x_shape.shape()[0], + y_shape.shape()[0]); // batch size + infer_context->AddEqualCstr(x_shape.shape()[2], y_shape.shape()[1]); + + infer_context->AddBroadcastableCstr(input_shape.shape()[0], + x_shape.shape()[0]); // batch size + infer_context->AddBroadcastableCstr(input_shape.shape()[1], + x_shape.shape()[1]); + infer_context->AddBroadcastableCstr(input_shape.shape()[2], + y_shape.shape()[2]); + + return true; +} + +bool Baddbmm_OpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return BaddbmmOpInferSymbolicShape(op, infer_context); +} + bool AucOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &predict_shape = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index c8dd6a2b048ce9..d647580d0242c0 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -21,6 +21,8 @@ namespace paddle::dialect { OP_DECLARE_INFER_SYMBOLIC_SHAPE(Accuracy) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Baddbmm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Baddbmm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) diff --git a/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h b/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h index 357b2434c1f676..020f4741ef63ad 100644 --- a/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h +++ b/paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h @@ -1409,6 +1409,28 @@ Tensor addmm_decomp(const Tensor& input, full_scalar(beta, input.dtype()) * input; } +template +Tensor baddbmm_decomp(const Tensor& input, + const Tensor& x, + const Tensor& y, + const float beta, + const float alpha) { + int batch_size = x.shape()[0]; + std::vector batch_results; + + for (int i = 0; i < batch_size; ++i) { + Tensor x_batch = get_slice(x, i); + Tensor y_batch = get_slice(y, i); + Tensor result = matmul(x_batch, y_batch); + batch_results.push_back(result); + } + + Tensor x_y_mat = concat(batch_results); + + return full_scalar(alpha, x_y_mat.dtype()) * x_y_mat + + full_scalar(beta, input.dtype()) * input; +} + template Tensor eye_decomp(const paddle::Scalar& num_rows, const paddle::Scalar& num_columns, diff --git a/paddle/phi/api/ext/tensor_compat.h b/paddle/phi/api/ext/tensor_compat.h index b1a140da46a890..6c09a4f7451c1e 100644 --- a/paddle/phi/api/ext/tensor_compat.h +++ b/paddle/phi/api/ext/tensor_compat.h @@ -35,6 +35,7 @@ using experimental::asinh; using experimental::atan; using experimental::atan2; using experimental::atanh; +using experimental::baddbmm; using experimental::bernoulli; using experimental::ceil; using experimental::cholesky; diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index e2566301a45b23..d04cfcd59bd57b 100644 --- 
a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -150,6 +150,96 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void BaddbmmInferMeta(const MetaTensor& input, + const MetaTensor& x, + const MetaTensor& y, + float beta, + float alpha, + MetaTensor* out) { + auto input_dims = input.dims(); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + auto ndim_input = input_dims.size(); + auto ndim_x = x_dims.size(); + auto ndim_y = y_dims.size(); + + VLOG(3) << "baddbmm operator input.shape=" << input_dims + << " x.shape=" << x_dims << " y.shape=" << y_dims << " beta=" << beta + << " alpha=" << alpha << " ndim_input=" << ndim_input + << " ndim_x=" << ndim_x << " ndim_y=" << ndim_y; + + PADDLE_ENFORCE_NE( + product(input_dims), + 0, + errors::PreconditionNotMet("The Input variable 'input' has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + + PADDLE_ENFORCE_NE( + product(x_dims), + 0, + errors::PreconditionNotMet("The Input variable 'x' has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + + PADDLE_ENFORCE_NE( + product(y_dims), + 0, + errors::PreconditionNotMet("The Input variable 'y' has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + // dim check + PADDLE_ENFORCE_EQ( + ndim_input, + 3, + errors::InvalidArgument("The input tensor input's dimension must be 3. " + "But received input's dimension = [%d].", + ndim_input)); + PADDLE_ENFORCE_EQ( + ndim_x, + 3, + errors::InvalidArgument("The input tensor x's dimension must be 3. " + "But received x's dimension = [%d].", + ndim_x)); + PADDLE_ENFORCE_EQ( + ndim_y, + 3, + errors::InvalidArgument("The input tensor y's dimension must be 3. " + "But received y's dimension = [%d].", + ndim_y)); + + PADDLE_ENFORCE_EQ( + input_dims[0], + x_dims[0], + errors::InvalidArgument( + "The batch size of input and x must be the same. " + "But received input batch size = [%d], x batch size = [%d].", + input_dims[0], + x_dims[0])); + PADDLE_ENFORCE_EQ( + x_dims[2], + y_dims[1], + errors::InvalidArgument("The second dimension of x must be equal to the " + "first dimension of y. 
" + "But received x's second dimension = [%d], y's " + "first dimension = [%d].", + x_dims[2], + y_dims[1])); + + std::vector output_dims; + output_dims.push_back(x_dims[0]); + output_dims.push_back(x_dims[1]); + output_dims.push_back(y_dims[2]); + + out->set_dims(common::make_ddim(output_dims)); + out->share_lod(input); + out->set_dtype(input.dtype()); +} + void AffineChannelInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index b05e64b4262123..ee7f484c5d2035 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -48,6 +48,13 @@ void AddmmInferMeta(const MetaTensor& input, float alpha, MetaTensor* out); +void BaddbmmInferMeta(const MetaTensor& input, + const MetaTensor& x, + const MetaTensor& y, + float beta, + float alpha, + MetaTensor* out); + void AffineChannelInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, diff --git a/paddle/phi/kernels/baddbmm_grad_kernel.h b/paddle/phi/kernels/baddbmm_grad_kernel.h new file mode 100644 index 00000000000000..34d237e379cb6e --- /dev/null +++ b/paddle/phi/kernels/baddbmm_grad_kernel.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BaddbmmGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + float alpha, + float beta, + DenseTensor* input_grad, + DenseTensor* x_grad, + DenseTensor* y_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/baddbmm_kernel.h b/paddle/phi/kernels/baddbmm_kernel.h new file mode 100644 index 00000000000000..a10a89d4beb44c --- /dev/null +++ b/paddle/phi/kernels/baddbmm_kernel.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template <typename T, typename Context> +void BaddbmmKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& x, + const DenseTensor& y, + float beta, + float alpha, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/baddbmm_grad_kernel.cc b/paddle/phi/kernels/cpu/baddbmm_grad_kernel.cc new file mode 100644 index 00000000000000..cfd36c73c9cb8c --- /dev/null +++ b/paddle/phi/kernels/cpu/baddbmm_grad_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/baddbmm_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + baddbmm_grad, CPU, ALL_LAYOUT, phi::BaddbmmGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/baddbmm_kernel.cc b/paddle/phi/kernels/cpu/baddbmm_kernel.cc new file mode 100644 index 00000000000000..7b616c924bf95c --- /dev/null +++ b/paddle/phi/kernels/cpu/baddbmm_kernel.cc @@ -0,0 +1,22 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/baddbmm_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/baddbmm_kernel_impl.h" + +PD_REGISTER_KERNEL( + baddbmm, CPU, ALL_LAYOUT, phi::BaddbmmKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/blas/blas.h b/paddle/phi/kernels/funcs/blas/blas.h index 2f27682247bdce..5e42107312138f 100644 --- a/paddle/phi/kernels/funcs/blas/blas.h +++ b/paddle/phi/kernels/funcs/blas/blas.h @@ -96,6 +96,18 @@ class Blas { T beta, T* C) const; + template + void GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + U alpha, + const T* A, + const T* B, + U beta, + T* C) const; + template void GEMM(bool transA, bool transB, @@ -292,6 +304,21 @@ class Blas { int64_t strideA, int64_t strideB) const; + template + void BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + U alpha, + const T* A, + const T* B, + U beta, + T* C, + int batchCount, + int64_t strideA, + int64_t strideB) const; + template void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h index 5fcc3f12f2b351..096ab5bd857ed8 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h @@ -1183,6 +1183,152 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 8000 } +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + U alpha, + const T *A, + const T *B, + U beta, + T *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + T t_alpha = static_cast(alpha); + T t_beta = static_cast(beta); + +#if CUDA_VERSION >= 8000 + if (FLAGS_enable_cublas_tensor_op_math && std::is_same::value) { + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &t_alpha, + B, + CUDA_R_32F, + ldb, + A, + CUDA_R_32F, + lda, + &t_beta, + C, + CUDA_R_32F, + N); + } else { +#endif // CUDA_VERSION >= 8000 + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &t_alpha, + B, + ldb, + A, + lda, + &t_beta, + C, + N); + }); + +#if CUDA_VERSION >= 8000 + } +#endif // CUDA_VERSION >= 8000 +} + +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const phi::dtype::float16 *A, + const phi::dtype::float16 *B, + float beta, + phi::dtype::float16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 53, + common::errors::InvalidArgument( + "cublas fp16 gemm requires GPU compute capability >= 53," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = alpha; + float h_beta = beta; + +#if CUDA_VERSION >= 8000 + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + auto &cuda_ctx = const_cast(context_); + CUBlas::GEMM_EX(&cuda_ctx, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16F, + ldb, + A, + CUDA_R_16F, + lda, + &h_beta, + C, + CUDA_R_16F, + N, + CUDA_R_32F); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + h_B, + ldb, + h_A, + lda, + &h_beta, + h_C, + N); + }); +#endif // CUDA_VERSION >= 8000 +} + template <> template <> inline void Blas::GEMM(CBLAS_TRANSPOSE transA, @@ -1252,6 +1398,75 @@ inline void Blas::GEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 11000 } +template <> +template <> +inline void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const phi::dtype::bfloat16 *A, + const phi::dtype::bfloat16 *B, + float beta, + phi::dtype::bfloat16 *C) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE_GE( + context_.GetComputeCapability(), + 80, + common::errors::InvalidArgument( + "cublas bf16 gemm requires GPU compute capability >= 80," + "but received %d", + context_.GetComputeCapability())); + + float h_alpha = alpha; + float h_beta = beta; + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + A, + CUDA_R_16BF, + lda, + &h_beta, + C, + CUDA_R_16BF, + N, + CUDA_R_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(common::errors::Unimplemented( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); + +#endif // CUDA_VERSION >= 11000 +} + template <> template <> inline void Blas::GEMM(CBLAS_TRANSPOSE transA, @@ -1781,6 +1996,125 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 9010 } +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + U alpha, + const T *A, + const T *B, + U beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. 
+ int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + +#if CUDA_VERSION >= 9010 + if ((FLAGS_enable_cublas_tensor_op_math && (std::is_same::value)) || + std::is_same::value) { + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " + << (use_tensor_op_math ? "True" : "False"); + VLOG(4) << "use_half_precision_compute_type: " + << FLAGS_gemm_use_half_precision_compute_type; + + auto fp = std::is_same::value ? CUDA_R_32F : CUDA_R_16F; +#if CUDA_VERSION >= 11000 + auto compute_type = CUBLAS_COMPUTE_32F; +#else + auto compute_type = CUDA_R_32F; +#endif + + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + void *a = static_cast(&h_alpha); + void *b = static_cast(&h_beta); + // set ComputeType as CUDA_R_32F for fp16, for better accuracy + if (FLAGS_gemm_use_half_precision_compute_type == true && + std::is_same::value) { + a = static_cast(&alpha); + b = static_cast(&beta); +#if CUDA_VERSION >= 11000 + compute_type = CUBLAS_COMPUTE_16F; +#else + compute_type = CUDA_R_16F; +#endif + } + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + a, + B, + fp, + ldb, + strideB, + A, + fp, + lda, + strideA, + b, + C, + fp, + ldc, + strideC, + batchCount, + compute_type, + algo)); + }); + } else { +#endif // CUDA_VERSION >= 9010 + + T h_alpha = static_cast(alpha); + T h_beta = static_cast(beta); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GEMM_STRIDED_BATCH(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + ldb, + strideB, + A, + lda, + strideA, + &h_beta, + C, + ldc, + strideC, + batchCount); + }); + +#if CUDA_VERSION >= 9010 + } +#endif // CUDA_VERSION >= 9010 +} + template <> template <> inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, @@ -1852,6 +2186,77 @@ inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif // CUDA_VERSION >= 11000 } +template <> +template <> +inline void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + float alpha, + const phi::dtype::bfloat16 *A, + const phi::dtype::bfloat16 *B, + float beta, + phi::dtype::bfloat16 *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { +#if CUDA_VERSION >= 11000 + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t strideC = M * N; + + float h_alpha = alpha; + float h_beta = beta; + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; + bool use_tensor_op_math = context_.tensor_core_available(); + if (use_tensor_op_math) { + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } + VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? 
"True" : "False"); + + context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) { + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasGemmStridedBatchedEx(handle, + cuTransB, + cuTransA, + N, + M, + K, + &h_alpha, + B, + CUDA_R_16BF, + ldb, + strideB, + A, + CUDA_R_16BF, + lda, + strideA, + &h_beta, + C, + CUDA_R_16BF, + ldc, + strideC, + batchCount, + CUBLAS_COMPUTE_32F, + algo)); + }); +#else + // raise error + PADDLE_THROW(common::errors::Unimplemented( + "cublasGemmStridedBatchedEx with bfloat16 is not supported on cuda <= " + "11")); +#endif // CUDA_VERSION >= 11000 +} + template <> template void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.h b/paddle/phi/kernels/funcs/blas/blas_impl.h index 098e37105a45dc..4c6bcca8d63d73 100644 --- a/paddle/phi/kernels/funcs/blas/blas_impl.h +++ b/paddle/phi/kernels/funcs/blas/blas_impl.h @@ -1080,6 +1080,37 @@ void Blas::GEMM(CBLAS_TRANSPOSE transA, ldc); } +template <> +template +void Blas::GEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + U alpha, + const T *A, + const T *B, + U beta, + T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + CBlas::GEMM(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} + template <> template void Blas::GEMM(bool transA, @@ -1410,6 +1441,66 @@ void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, #endif } +template <> +template +void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, + CBLAS_TRANSPOSE transB, + int M, + int N, + int K, + U alpha, + const T *A, + const T *B, + U beta, + T *C, + int batchCount, + int64_t strideA, + int64_t strideB) const { + PADDLE_ENFORCE_NOT_NULL( + A, common::errors::InvalidArgument("Pointer A should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + B, common::errors::InvalidArgument("Pointer B should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + C, common::errors::InvalidArgument("Pointer C should not be null.")); +#ifdef PADDLE_WITH_MKLML + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + auto a_array = std::vector(batchCount); + auto b_array = std::vector(batchCount); + auto c_array = std::vector(batchCount); + for (int k = 0; k < batchCount; ++k) { + a_array[k] = &A[k * strideA]; + b_array[k] = &B[k * strideB]; + c_array[k] = &C[k * M * N]; + } + + CBlas::GEMM_BATCH(CblasRowMajor, + &transA, + &transB, + &M, + &N, + &K, + &alpha, + a_array.data(), + &lda, + b_array.data(), + &ldb, + &beta, + c_array.data(), + &ldc, + 1 /* group_count */, + &batchCount); +#else + for (int k = 0; k < batchCount; ++k) { + auto *Ak = &A[k * strideA]; + auto *Bk = &B[k * strideB]; + auto *Ck = &C[k * M * N]; + this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); + } +#endif +} + template <> template void Blas::BatchedGEMM(CBLAS_TRANSPOSE transA, diff --git a/paddle/phi/kernels/gpu/baddbmm_grad_kernel.cu b/paddle/phi/kernels/gpu/baddbmm_grad_kernel.cu new file mode 100644 index 00000000000000..5dcf03c7458ad0 --- /dev/null +++ b/paddle/phi/kernels/gpu/baddbmm_grad_kernel.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/baddbmm_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(baddbmm_grad, + GPU, + ALL_LAYOUT, + phi::BaddbmmGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/baddbmm_kernel.cu b/paddle/phi/kernels/gpu/baddbmm_kernel.cu new file mode 100644 index 00000000000000..0e41074119eee0 --- /dev/null +++ b/paddle/phi/kernels/gpu/baddbmm_kernel.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/baddbmm_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/baddbmm_kernel_impl.h" + +PD_REGISTER_KERNEL(baddbmm, + GPU, + ALL_LAYOUT, + phi::BaddbmmKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h new file mode 100644 index 00000000000000..238f50c5551947 --- /dev/null +++ b/paddle/phi/kernels/impl/baddbmm_grad_kernel_impl.h @@ -0,0 +1,258 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#pragma once + +#include + +#include "glog/logging.h" + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/baddbmm_grad_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/for_range.h" + +namespace phi { + +template +struct BCopyOrScaleFunctor { + BCopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel) + : scale_(scale), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MPType = typename phi::dtype::MPTypeTrait::Type; + const MPType mp_scale = static_cast(scale_); + const MPType mp_x = static_cast(x_[idx]); + output_[idx] = static_cast(mp_scale * mp_x); + } + + private: + const float scale_; + const T* x_; + T* output_; + int64_t numel_; +}; + +template +using PhiEigenTensor = EigenTensor; + +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; +using Array3 = Eigen::DSizes; + +template +void BaddbmmGradKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + float alpha, + float beta, + DenseTensor* input_grad, + DenseTensor* x_grad, + DenseTensor* y_grad) { + using MPType = typename phi::dtype::MPTypeTrait::Type; + bool is_float16_or_bfloat16 = false; + if (std::is_same::value || + std::is_same::value) { + is_float16_or_bfloat16 = true; + } + + auto in_dims = input.dims(); + int total_elems = 0; + + VLOG(3) << "alpha: " << alpha << " beta: " << beta; + + if (input_grad != nullptr) { + input_grad->set_lod(out_grad.lod()); + } + if (x_grad != nullptr) { + x_grad->set_lod(x.lod()); + } + if (y_grad != nullptr) { + y_grad->set_lod(y.lod()); + } + + auto blas = funcs::GetBlas(dev_ctx); + auto mt_blas = funcs::GetBlas(dev_ctx); + if (input_grad) { + dev_ctx.template Alloc(input_grad); + total_elems = in_dims[0] * in_dims[1] * in_dims[2]; + auto& place = *dev_ctx.eigen_device(); + auto eigen_dout = PhiEigenTensor::From(out_grad); + auto eigen_dinput = PhiEigenTensor::From(*input_grad); + + bool batch_compress = in_dims[0] != out_grad.dims()[0]; + bool row_compress = in_dims[1] != out_grad.dims()[1]; + bool col_compress = in_dims[2] != out_grad.dims()[2]; + auto eigen_dinput_shape = Array3( + input_grad->dims()[0], input_grad->dims()[1], input_grad->dims()[2]); + + if (batch_compress && row_compress && col_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum().eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum() + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else if (batch_compress && row_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array2(0, 1)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array2(0, 1)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else if (batch_compress && col_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array2(0, 2)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array2(0, 2)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else if (row_compress && col_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + 
eigen_dout.sum(Array2(1, 2)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array2(1, 2)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else if (batch_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array1(0)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else if (row_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array1(1)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else if (col_compress) { + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(2)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array1(2)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } + } else { + // The VCOPY does not support the float16, bfloat16 + if (!is_float16_or_bfloat16) { + mt_blas.VCOPY( + total_elems, out_grad.data(), input_grad->data()); + } else { + phi::funcs::ForRange for_range(dev_ctx, total_elems); + BCopyOrScaleFunctor functor( + 1, out_grad.data(), input_grad->data(), total_elems); + for_range(functor); + } + } + + // The SCAL does not support the float16, bfloat16 + if (!is_float16_or_bfloat16) { + mt_blas.SCAL(total_elems, beta, input_grad->data()); + } else { + phi::funcs::ForRange for_range(dev_ctx, total_elems); + BCopyOrScaleFunctor functor( + beta, input_grad->data(), input_grad->data(), total_elems); + for_range(functor); + } + } + if (x_grad) { + dev_ctx.template Alloc(x_grad); + total_elems = x.dims()[0] * x.dims()[1] * x.dims()[2]; + // x_grad = out_grad * y'. x_grad: B x M x K, out_grad : B x M x N, y : B x + // K x N + for (int i = 0; i < x.dims()[0]; ++i) { + auto out_grad_slice = out_grad.Slice(i, i + 1); + auto y_slice = y.Slice(i, i + 1); + auto x_grad_slice = x_grad->Slice(i, i + 1); + auto x_grad_dims = x_grad_slice.dims(); + + x_grad_slice.Resize({x_grad_dims[1], x_grad_dims[2]}); + y_slice.Resize({y_slice.dims()[1], y_slice.dims()[2]}); + out_grad_slice.Resize( + {out_grad_slice.dims()[1], out_grad_slice.dims()[2]}); + blas.MatMul(out_grad_slice, false, y_slice, true, &x_grad_slice); + } + if (!is_float16_or_bfloat16) { + mt_blas.SCAL(total_elems, alpha, x_grad->data()); + } else { + phi::funcs::ForRange for_range(dev_ctx, total_elems); + BCopyOrScaleFunctor functor( + alpha, x_grad->data(), x_grad->data(), total_elems); + for_range(functor); + } + } + if (y_grad) { + dev_ctx.template Alloc(y_grad); + total_elems = y.dims()[0] * y.dims()[1] * y.dims()[2]; + // y_grad = x' * out_grad. 
y_grad: B x K x N, out_grad : B x M x N, x : B x + // M x K + for (int i = 0; i < x.dims()[0]; ++i) { + auto out_grad_slice = out_grad.Slice(i, i + 1); + auto x_slice = x.Slice(i, i + 1); + auto y_grad_slice = y_grad->Slice(i, i + 1); + out_grad_slice.Resize( + {out_grad_slice.dims()[1], out_grad_slice.dims()[2]}); + x_slice.Resize({x_slice.dims()[1], x_slice.dims()[2]}); + y_grad_slice.Resize({y_grad_slice.dims()[1], y_grad_slice.dims()[2]}); + blas.MatMul(x_slice, true, out_grad_slice, false, &y_grad_slice); + } + if (!is_float16_or_bfloat16) { + mt_blas.SCAL(total_elems, alpha, y_grad->data()); + } else { + phi::funcs::ForRange for_range(dev_ctx, total_elems); + BCopyOrScaleFunctor functor( + alpha, y_grad->data(), y_grad->data(), total_elems); + for_range(functor); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/baddbmm_kernel_impl.h b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h new file mode 100644 index 00000000000000..83fedc4bcbc626 --- /dev/null +++ b/paddle/phi/kernels/impl/baddbmm_kernel_impl.h @@ -0,0 +1,185 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "glog/logging.h" + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/baddbmm_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +using PhiEigenTensor = EigenTensor; + +using Array1 = Eigen::DSizes; +using Array2 = Eigen::DSizes; +using Array3 = Eigen::DSizes; + +template +void BaddbmmKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& x, + const DenseTensor& y, + float beta, + float alpha, + DenseTensor* out) { + auto input_dims = input.dims(); + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + DenseTensor input_3d(input); + if (input.dims().size() == 2) { + input_dims = {1, input.dims()[0], input.dims()[1]}; + input_3d.Resize(input_dims); + } + + // broadcast mode check + if (x_dims[0] != input_dims[0]) { + PADDLE_ENFORCE_EQ(input_dims[0], + 1, + errors::InvalidArgument( + "When x_dims[0] is not equal with input_dims[0], " + "input_dims[0] must be 1 but got %s", + input_dims[0])); + PADDLE_ENFORCE_EQ(y_dims[2] == input_dims[2] || input_dims[2] == 1, + true, + errors::InvalidArgument( + "The input tensor shape mismatch, input shape=[%s], " + "x shape=[%s], y shape=[%s]", + input_dims, + x_dims, + y_dims)); + } + if (y_dims[2] != input_dims[2]) { + PADDLE_ENFORCE_EQ(input_dims[2], + 1, + errors::InvalidArgument( + "When y_dims[2] is not equal with input_dims[2], " + "input_dims[2] must be 1 but got %s", + input_dims[2])); + PADDLE_ENFORCE_EQ(x_dims[0] == input_dims[0] || input_dims[0] == 1, + true, + errors::InvalidArgument( + "The input tensor shape mismatch, input shape=[%s], " + "x shape=[%s], y shape=[%s]", + input_dims, + x_dims, + y_dims)); + } + PADDLE_ENFORCE_EQ( + x_dims[2], + y_dims[1], + 
errors::InvalidArgument( + "The input tensor X's width must be equal with matrix Y' height. " + "But received X's shape = [%s], Y's shape = [%s].", + x_dims[2], + y_dims[1])); + + dev_ctx.template Alloc(out); + auto blas = funcs::GetBlas(dev_ctx); + + // calc broadcast dim + Array3 bcast_dims; + bcast_dims[0] = x_dims[0] / input_dims[0]; + bcast_dims[1] = x_dims[1] / input_dims[1]; + bcast_dims[2] = y_dims[2] / input_dims[2]; + VLOG(3) << "bcast_dims=[" << bcast_dims[0] << "," << bcast_dims[1] << "," + << bcast_dims[2] << "]"; + + // broadcast using eigen + const DenseTensor& const_ref_input = input_3d; + auto eigen_input = PhiEigenTensor::From(const_ref_input); + auto eigen_out = PhiEigenTensor::From(*out); + auto& place = *dev_ctx.eigen_device(); + funcs::EigenBroadcast, T, 3>::Eval( + place, eigen_out, eigen_input, bcast_dims); + + using MPType = typename phi::dtype::MPTypeTrait::Type; + + // special case for MPType + if constexpr (std::is_same_v) { + VLOG(4) << "Function: baddbmm, Type of T: " << typeid(T).name(); + VLOG(4) << "Function: baddbmm, Type of MPType: " << typeid(MPType).name(); + float t_alpha = alpha; + float t_beta = beta; + if (x_dims[0] == 1) { + blas.GEMM(CblasNoTrans, + CblasNoTrans, + x_dims[1], + y_dims[2], + x_dims[2], + t_alpha, + x.data(), + y.data(), + t_beta, + out->data()); + } else { + blas.BatchedGEMM(CblasNoTrans, + CblasNoTrans, + x_dims[1], + y_dims[2], + x_dims[2], + t_alpha, + x.data(), + y.data(), + t_beta, + out->data(), + x_dims[0], + x_dims[1] * x_dims[2], + x_dims[2] * y_dims[2]); + } + } else { + T t_alpha = static_cast(alpha); + T t_beta = static_cast(beta); + if (x_dims[0] == 1) { + blas.GEMM(CblasNoTrans, + CblasNoTrans, + x_dims[1], + y_dims[2], + x_dims[2], + t_alpha, + x.data(), + y.data(), + t_beta, + out->data()); + } else { + blas.BatchedGEMM(CblasNoTrans, + CblasNoTrans, + x_dims[1], + y_dims[2], + x_dims[2], + t_alpha, + x.data(), + y.data(), + t_beta, + out->data(), + x_dims[0], + x_dims[1] * x_dims[2], + x_dims[2] * y_dims[2]); + // x_dims[2] == y_dims[1] + } + } +} + +} // namespace phi diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml index dae96ff02fffe9..aa09c21f77fb93 100644 --- a/paddle/phi/ops/yaml/backward.yaml +++ b/paddle/phi/ops/yaml/backward.yaml @@ -223,6 +223,16 @@ func : atanh_grad inplace : (out_grad -> x_grad) +- backward_op : baddbmm_grad + forward : baddbmm (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0) -> Tensor(out) + args : (Tensor input, Tensor x, Tensor y, Tensor out_grad, float alpha, float beta) + output : Tensor(input_grad), Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralTernaryGradInferMeta + param : [input, x, y] + kernel : + func : baddbmm_grad + - backward_op : batch_fc_grad forward : batch_fc (Tensor input, Tensor w, Tensor bias) -> Tensor(out) args : (Tensor input, Tensor w, Tensor bias, Tensor out_grad) diff --git a/paddle/phi/ops/yaml/op_compat.yaml b/paddle/phi/ops/yaml/op_compat.yaml index 89a91aa264893a..d3707eef9a361c 100755 --- a/paddle/phi/ops/yaml/op_compat.yaml +++ b/paddle/phi/ops/yaml/op_compat.yaml @@ -372,6 +372,17 @@ outputs : {auc : AUC, stat_pos_out : StatPosOut, stat_neg_out : StatNegOut} +- op : baddbmm + backward : baddbmm_grad + inputs : + {input : Input, x : X, y : Y} + outputs : + out : Out + attrs : + {alpha : Alpha, beta : Beta} + extra : + attrs : [bool use_mkldnn = false] + - op : barrier inputs : {x : X} diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 
2818ec6d89343c..6ac560334b5a29 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -490,6 +490,18 @@ inplace : (in_sum_1 -> out_sum_1), (in_sum_2 -> out_sum_2), (in_sum_3 -> out_sum_3), (in_num_accumulates -> out_num_accumulates), (in_old_num_accumulates -> out_old_num_accumulates), (in_num_updates -> out_num_updates) traits : paddle::dialect::ForwardOnlyTrait +- op : baddbmm + args : (Tensor input, Tensor x, Tensor y, float beta=1.0, float alpha=1.0) + output : Tensor(out) + infer_meta : + func : BaddbmmInferMeta + kernel : + func : baddbmm + data_type : x + inplace: (input -> out) + backward : baddbmm_grad + # interfaces : paddle::dialect::InferSymbolicShapeInterface + - op : barrier args : (Tensor x, int ring_id=0) output : Tensor(out) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 7cac26bda3790e..b6659a5b583cb9 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -386,6 +386,8 @@ atan_, atanh, atanh_, + baddbmm, + baddbmm_, bitwise_left_shift, bitwise_left_shift_, bitwise_right_shift, @@ -805,6 +807,8 @@ 'raw', 'addmm', 'addmm_', + 'baddbmm', + 'baddbmm_', 'allclose', 'isclose', 't', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ff8eed1aa17159..6ff77cad6fcac8 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -257,6 +257,8 @@ atan_, atanh, atanh_, + baddbmm, + baddbmm_, bitwise_left_shift, bitwise_left_shift_, bitwise_right_shift, @@ -612,6 +614,8 @@ 'erf', 'addmm', 'addmm_', + 'baddbmm', + 'baddbmm_', 'clip', 'clip_', 'trace', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 426d7c979bc915..5e75c10eaff557 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2543,6 +2543,163 @@ def addmm_( return _C_ops.addmm_(input, x, y, beta, alpha) +def baddbmm( + input: Tensor, + x: Tensor, + y: Tensor, + beta: float = 1.0, + alpha: float = 1.0, + name: str | None = None, +) -> Tensor: + """ + **baddbmm** + + Perform batch matrix multiplication for input $x$ and $y$. + $input$ is added to the final result. + The equation is: + + .. math:: + Out = alpha * x * y + beta * input + + $Input$, $x$ and $y$ can carry the LoD (Level of Details) information, or not. But the output only shares the LoD information with input $input$. + + Args: + input (Tensor): The input Tensor to be added to the final result. + x (Tensor): The first input Tensor for batch matrix multiplication. + y (Tensor): The second input Tensor for batch matrix multiplication. + beta (float, optional): Coefficient of $input$, default is 1. + alpha (float, optional): Coefficient of $x*y$, default is 1. + name (str|None, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The output Tensor of baddbmm. + + Examples: + .. 
code-block:: python + + >>> import paddle + + >>> x = paddle.ones([2, 2, 2]) + >>> y = paddle.ones([2, 2, 2]) + >>> input = paddle.ones([2, 2, 2]) + + >>> out = paddle.baddbmm(input=input, x=x, y=y, beta=0.5, alpha=5.0) + + >>> print(out) + Tensor(shape=[2, 2, 2], dtype=float32, place=Place(cpu), stop_gradient=True, + [[[10.50000000, 10.50000000], + [10.50000000, 10.50000000]], + [[10.50000000, 10.50000000], + [10.50000000, 10.50000000]]]) + """ + input_shape = input.shape + x_shape = x.shape + y_shape = y.shape + if not len(x_shape) == len(y_shape) == 3: + raise ValueError( + f"The dimension of x, y should be 3 but receive x's shape: {x_shape}, y's shape: {y_shape}" + ) + if x_shape[2] != y_shape[1]: + raise ValueError( + f"The input Variable x's width must be equal with Variable y's height. But received x's shape = {x_shape}, y's shape = {y_shape}." + ) + + if len(input_shape) == 3: + if input_shape[0] != x_shape[0]: + raise ValueError( + f"The batch size of input must be equal to the batch size of x. But received input's batch size = {input_shape[0]}, x's batch size = {x_shape[0]}" + ) + if input_shape[1] != x_shape[1]: + if input_shape[1] != 1: + raise ValueError( + f"When x's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {input_shape[1]}" + ) + if input_shape[2] != y_shape[2]: + if input_shape[2] != 1: + raise ValueError( + f"When y's dimension[2] is not equal with input's dimension[2], input's dimension[2] must be 1 but got {input_shape[2]}" + ) + else: + raise ValueError( + f"The dimension of input should be 3 but received input's shape: {input_shape}" + ) + + if in_dynamic_or_pir_mode(): + return _C_ops.baddbmm(input, x, y, beta, alpha) + else: + inputs = {'Input': input, "X": x, "Y": y} + attrs = {'Alpha': alpha, 'Beta': beta} + + helper = LayerHelper("baddbmm", **locals()) + check_variable_and_dtype( + input, + 'Input', + ['float16', 'float32', 'float64', 'uint16'], + 'baddbmm', + ) + check_variable_and_dtype( + x, 'X', ['float16', 'float32', 'float64', 'uint16'], 'baddbmm' + ) + check_variable_and_dtype( + y, 'Y', ['float16', 'float32', 'float64', 'uint16'], 'baddbmm' + ) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type="baddbmm", inputs=inputs, attrs=attrs, outputs={"Out": out} + ) + return out + + +@inplace_apis_in_dygraph_only +def baddbmm_( + input: Tensor, + x: Tensor, + y: Tensor, + beta: float = 1.0, + alpha: float = 1.0, + name: str | None = None, +) -> Tensor: + """ + Inplace version of ``baddbmm`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_baddbmm`. + """ + input_shape = input.shape + x_shape = x.shape + y_shape = y.shape + if not len(x_shape) == len(y_shape) == 3: + raise ValueError( + f"The dimension of x, y should be 3 but receive x's shape: {x_shape}, y's shape: {y_shape}" + ) + if x_shape[2] != y_shape[1]: + raise ValueError( + f"The input Variable x's width must be equal with Variable y's height. But received x's shape = {x_shape}, y's shape = {y_shape}." + ) + + if len(input_shape) == 3: + if input_shape[0] != x_shape[0]: + raise ValueError( + f"The batch size of input must be equal to the batch size of x. 
But received input's batch size = {input_shape[0]}, x's batch size = {x_shape[0]}" + ) + if input_shape[1] != x_shape[1]: + if input_shape[1] != 1: + raise ValueError( + f"When x's dimension[1] is not equal with input's dimension[1], input's dimension[1] must be 1 but got {input_shape[1]}" + ) + if input_shape[2] != y_shape[2]: + if input_shape[2] != 1: + raise ValueError( + f"When y's dimension[2] is not equal with input's dimension[2], input's dimension[2] must be 1 but got {input_shape[2]}" + ) + else: + raise ValueError( + f"The dimension of input should be 3 but received input's shape: {input_shape}" + ) + + if in_dynamic_mode(): + return _C_ops.baddbmm_(input, x, y, beta, alpha) + + def renorm(x: Tensor, p: float, axis: int, max_norm: float) -> Tensor: """ **renorm** diff --git a/test/legacy_test/test_baddbmm_op.py b/test/legacy_test/test_baddbmm_op.py new file mode 100644 index 00000000000000..0be1d00141c09a --- /dev/null +++ b/test/legacy_test/test_baddbmm_op.py @@ -0,0 +1,93 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language go verning permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest + +import paddle + + +class TestBaddBmmOp(OpTest): + # test basic + def setUp(self): + self.op_type = "baddbmm" + self.prim_op_type = "comp" + self.python_api = paddle.baddbmm + self.public_python_api = paddle.baddbmm + self.init_dtype_type() + self.inputs = { + 'Input': np.random.random((10, 20, 15)).astype(self.dtype), + 'X': np.random.random((10, 20, 10)).astype(self.dtype), + 'Y': np.random.random((10, 10, 15)).astype(self.dtype), + } + self.outputs = { + 'Out': self.inputs['Input'] + + np.matmul(self.inputs['X'], self.inputs['Y']) + } + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output(check_pir=True, check_prim_pir=True) + + def test_check_grad_normal(self): + self.check_grad( + ['Input', 'X', 'Y'], + 'Out', + check_pir=True, + check_prim_pir=True, + ) + + def test_check_grad_x(self): + self.check_grad( + ['X'], + 'Out', + no_grad_set=None, + check_pir=True, + check_prim_pir=True, + ) + + def test_check_grad_y(self): + self.check_grad( + ['Y'], + 'Out', + no_grad_set=None, + check_pir=True, + check_prim_pir=True, + ) + + def test_check_grad_input(self): + self.check_grad( + ['Input'], + 'Out', + no_grad_set=None, + check_pir=True, + check_prim_pir=True, + ) + + +class TestBaddBmmFP16Op(TestBaddBmmOp): + def init_dtype_type(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output(atol=1e-2) + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()