diff --git a/paddle/phi/backends/dynload/cusolver.h b/paddle/phi/backends/dynload/cusolver.h
index 74c64085ea7210..adbc5cdf0b6e92 100644
--- a/paddle/phi/backends/dynload/cusolver.h
+++ b/paddle/phi/backends/dynload/cusolver.h
@@ -65,6 +65,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
 
 #if CUDA_VERSION >= 9020
 #define CUSOLVER_ROUTINE_EACH_R1(__macro) \
+  __macro(cusolverDnSgetrs);              \
+  __macro(cusolverDnDgetrs);              \
   __macro(cusolverDnSpotrfBatched);       \
   __macro(cusolverDnDpotrfBatched);       \
   __macro(cusolverDnSpotrsBatched);       \
diff --git a/paddle/phi/backends/dynload/lapack.h b/paddle/phi/backends/dynload/lapack.h
index eaea6783824abc..17091a904a126b 100644
--- a/paddle/phi/backends/dynload/lapack.h
+++ b/paddle/phi/backends/dynload/lapack.h
@@ -29,6 +29,26 @@ extern "C" void dgetrf_(
 extern "C" void sgetrf_(
     int *m, int *n, float *a, int *lda, int *ipiv, int *info);
 
+// getrs_
+extern "C" void sgetrs_(char *trans,
+                        int *n,
+                        int *nrhs,
+                        float *a,
+                        int *lda,
+                        int *ipiv,
+                        float *b,
+                        int *ldb,
+                        int *info);
+extern "C" void dgetrs_(char *trans,
+                        int *n,
+                        int *nrhs,
+                        double *a,
+                        int *lda,
+                        int *ipiv,
+                        double *b,
+                        int *ldb,
+                        int *info);
+
 // evd
 extern "C" void zheevd_(char *jobz,
                         char *uplo,
diff --git a/paddle/phi/kernels/cpu/lu_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/lu_solve_grad_kernel.cc
new file mode 100644
index 00000000000000..bf832c5e67e681
--- /dev/null
+++ b/paddle/phi/kernels/cpu/lu_solve_grad_kernel.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/phi/kernels/lu_solve_grad_kernel.h"
+#include "paddle/phi/kernels/lu_solve_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LuSolveGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& lu,
+                       const DenseTensor& pivots,
+                       const DenseTensor& out,
+                       const DenseTensor& out_grad,
+                       const std::string& trans,
+                       DenseTensor* x_grad,
+                       DenseTensor* lu_grad) {
+  // Allocate memory for x_grad
+  dev_ctx.template Alloc<T>(x_grad);
+
+  // Use the forward kernel to compute the gradient
+  LuSolveKernel<T, Context>(dev_ctx, out_grad, lu, pivots, trans, x_grad);
+}
+
+}  // namespace phi
+
+// Register the CPU backward kernel
+PD_REGISTER_KERNEL(
+    lu_solve_grad, CPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/cpu/lu_solve_kernel.cc b/paddle/phi/kernels/cpu/lu_solve_kernel.cc
new file mode 100644
index 00000000000000..f006705cd0df0b
--- /dev/null
+++ b/paddle/phi/kernels/cpu/lu_solve_kernel.cc
@@ -0,0 +1,85 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/lu_solve_kernel.h"
+#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LuSolveKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& lu,
+                   const DenseTensor& pivots,
+                   const std::string& trans,
+                   DenseTensor* out) {
+  // Get lu matrix dimensions
+  auto lu_dims = lu.dims();
+  // Get x matrix dimensions
+  auto x_dims = x.dims();
+
+  // Allocate output tensor
+  dev_ctx.template Alloc<T>(out);
+  // Copy RHS data to output (will be overwritten with solution)
+  // phi::Copy(dev_ctx, x, x.place(), false, out);
+  *out = Transpose2DTo6D<Context, T>(dev_ctx, x);
+
+  // Prepare LAPACK parameters
+  char trans_char = (trans == "N") ? 'N' : ((trans == "T") ? 'T' : 'C');
+  int n_int = lu_dims[lu_dims.size() - 1];
+  int nrhs_int = x_dims[x_dims.size() - 1];
+  int lda = std::max(1, n_int);  // Leading dimension of A (LU matrix)
+  int ldb = std::max(1, n_int);  // Leading dimension of B (RHS/solution matrix)
+  int info = 0;
+
+  auto outdims = out->dims();
+  auto outrank = outdims.size();
+  int batchsize = product(common::slice_ddim(outdims, 0, outrank - 2));
+  auto out_data = out->data<T>();
+  auto lu_data = reinterpret_cast<T*>(const_cast<T*>(lu.data<T>()));
+  auto pivots_data = reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));
+
+  for (int i = 0; i < batchsize; i++) {
+    auto* out_data_item = &out_data[i * ldb * nrhs_int];
+    auto* lu_data_item = &lu_data[i * lda * n_int];
+    auto* pivots_data_item = &pivots_data[i * n_int];
+    phi::funcs::lapackLuSolve<T>(trans_char,
+                                 n_int,
+                                 nrhs_int,
+                                 lu_data_item,
+                                 lda,
+                                 pivots_data_item,
+                                 out_data_item,
+                                 ldb,
+                                 &info);
+    PADDLE_ENFORCE_EQ(
+        info,
+        0,
+        phi::errors::PreconditionNotMet(
+            "LU solve failed with error code %d. Check if matrix is singular.",
+            info));
+  }
+  *out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    lu_solve, CPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
\ No newline at end of file
diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.cc b/paddle/phi/kernels/funcs/lapack/lapack_function.cc
index ebfd53291c36fa..fc07f314446d92 100644
--- a/paddle/phi/kernels/funcs/lapack/lapack_function.cc
+++ b/paddle/phi/kernels/funcs/lapack/lapack_function.cc
@@ -30,6 +30,33 @@ void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) {
   dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
 }
 
+// lu_solve
+template <>
+void lapackLuSolve(char trans,
+                   int n,
+                   int nrhs,
+                   double* a,
+                   int lda,
+                   int* ipiv,
+                   double* b,
+                   int ldb,
+                   int* info) {
+  dynload::dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
+}
+
+template <>
+void lapackLuSolve(char trans,
+                   int n,
+                   int nrhs,
+                   float* a,
+                   int lda,
+                   int* ipiv,
+                   float* b,
+                   int ldb,
+                   int* info) {
+  dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
+}
+
 // eigh
 template <>
 void lapackEigh(char jobz,
diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.h b/paddle/phi/kernels/funcs/lapack/lapack_function.h
index d251095bb79f06..e54792e1c5bb27 100644
--- a/paddle/phi/kernels/funcs/lapack/lapack_function.h
+++ b/paddle/phi/kernels/funcs/lapack/lapack_function.h
@@ -21,6 +21,18 @@ namespace funcs {
 template <typename T>
 void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);
 
+// Lu_solve
+template <typename T>
+void lapackLuSolve(char trans,
+                   int n,
+                   int nrhs,
+                   T *a,
+                   int lda,
+                   int *ipiv,
+                   T *b,
+                   int ldb,
+                   int *info);
+
 // Eigh
 template <typename T>
 void lapackEigh(char jobz,
diff --git a/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
new file mode 100644
index 00000000000000..a1ef7138085415
--- /dev/null
+++ b/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
@@ -0,0 +1,44 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/phi/kernels/lu_solve_grad_kernel.h"
+#include "paddle/phi/kernels/lu_solve_kernel.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LuSolveGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& lu,
+                       const DenseTensor& pivots,
+                       const DenseTensor& out,
+                       const DenseTensor& out_grad,
+                       const std::string& trans,
+                       DenseTensor* x_grad,
+                       DenseTensor* lu_grad) {
+  // Allocate memory for x_grad
+  dev_ctx.template Alloc<T>(x_grad);
+
+  // Use the forward kernel to compute the gradient
+  LuSolveKernel<T, Context>(dev_ctx, out_grad, lu, pivots, trans, x_grad);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    lu_solve_grad, GPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
diff --git a/paddle/phi/kernels/gpu/lu_solve_kernel.cu b/paddle/phi/kernels/gpu/lu_solve_kernel.cu
new file mode 100644
index 00000000000000..e7d16430bd4905
--- /dev/null
+++ b/paddle/phi/kernels/gpu/lu_solve_kernel.cu
@@ -0,0 +1,154 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef PADDLE_WITH_HIP
+// HIP not support cusolver
+
+#include "paddle/phi/backends/dynload/cusolver.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/kernels/lu_solve_kernel.h"
+
+namespace phi {
+
+template <typename T>
+void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
+                    cublasOperation_t trans,
+                    int n,
+                    int nrhs,
+                    T *a,
+                    int lda,
+                    int *ipiv,
+                    T *b,
+                    int ldb,
+                    int *info);
+
+template <>
+void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
+                    cublasOperation_t trans,
+                    int n,
+                    int nrhs,
+                    float *a,
+                    int lda,
+                    int *ipiv,
+                    float *b,
+                    int ldb,
+                    int *info) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgetrs(
+      cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
+}
+
+template <>
+void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
+                    cublasOperation_t trans,
+                    int n,
+                    int nrhs,
+                    double *a,
+                    int lda,
+                    int *ipiv,
+                    double *b,
+                    int ldb,
+                    int *info) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrs(
+      cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
+}
+
+template <typename T, typename Context>
+void lu_solve_kernel(const Context& dev_ctx,
+                     cublasOperation_t trans,
+                     int n,
+                     int nrhs,
+                     T *a,
+                     int lda,
+                     int *ipiv,
+                     T *b,
+                     int ldb,
+                     int *info) {
+  /* step 1: get cusolver handle*/
+  auto cusolverH = dev_ctx.cusolver_dn_handle();
+  /* step 2: LU_SOLVE factorization */
+  cusolver_getrs(cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info);
+  PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
+}
+
+
+template <typename T, typename Context>
+void LuSolveKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& lu,
+                   const DenseTensor& pivots,
+                   const std::string& trans,
+                   DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  // Copy x to out since cusolverDn*getrs overwrites the input
+  phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+
+  // Validate input dimensions
+  auto x_dims = x.dims();
+  auto lu_dims = lu.dims();
+
+  cublasOperation_t trans_op;
+  if (trans == "N") {
+    trans_op = CUBLAS_OP_N;
+  } else if (trans == "T") {
+    trans_op = CUBLAS_OP_T;
+  } else if (trans == "C") {
+    trans_op = CUBLAS_OP_C;
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "trans must be one of ['N', 'T', 'C'], but got %s", trans));
+  }
+  int n = static_cast<int>(lu_dims[lu_dims.size() - 1]);
+  int nrhs = static_cast<int>(x_dims[x_dims.size() - 1]);
+  int lda = std::max(1, n);
+  int ldb = std::max(1, n);
+
+  DenseTensor info_tensor;
+  info_tensor.Resize({1});
+  dev_ctx.template Alloc<int>(&info_tensor);
+  int* d_info = info_tensor.data<int>();
+
+  auto outdims = out->dims();
+  auto outrank = outdims.size();
+  int batchsize = product(common::slice_ddim(outdims, 0, outrank - 2));
+  auto out_data = out->data<T>();
+  auto lu_data = reinterpret_cast<T*>(const_cast<T*>(lu.data<T>()));
+  auto pivots_data = reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));
+  for (int i = 0; i < batchsize; i++) {
+    // Each batch item of the RHS/solution holds n * nrhs elements.
+    auto* out_data_item = &out_data[i * n * nrhs];
+    auto* lu_data_item = &lu_data[i * n * n];
+    auto* pivots_data_item = &pivots_data[i * n];
+    lu_solve_kernel(dev_ctx,
+                    trans_op,
+                    n,
+                    nrhs,
+                    lu_data_item,
+                    lda,
+                    pivots_data_item,
+                    out_data_item,
+                    ldb,
+                    d_info);
+  }
+  // Synchronize to ensure the solve is complete
+  dev_ctx.Wait();
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    lu_solve, GPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
+
+#endif  // not PADDLE_WITH_HIP
\ No newline at end of file
diff --git a/paddle/phi/kernels/lu_solve_grad_kernel.h b/paddle/phi/kernels/lu_solve_grad_kernel.h
new file mode 100644
index 00000000000000..8d95d44b4b1030
--- /dev/null
+++ b/paddle/phi/kernels/lu_solve_grad_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LuSolveGradKernel(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       const DenseTensor& lu,
+                       const DenseTensor& pivots,
+                       const DenseTensor& out,
+                       const DenseTensor& out_grad,
+                       const std::string& trans,
+                       DenseTensor* x_grad,
+                       DenseTensor* lu_grad);
+
+}  // namespace phi
\ No newline at end of file
diff --git a/paddle/phi/kernels/lu_solve_kernel.h b/paddle/phi/kernels/lu_solve_kernel.h
new file mode 100644
index 00000000000000..e447ed38b8bcf0
--- /dev/null
+++ b/paddle/phi/kernels/lu_solve_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LuSolveKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& lu,
+                   const DenseTensor& pivots,
+                   const std::string& trans,
+                   DenseTensor* out);
+
+}  // namespace phi
\ No newline at end of file
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index 2f15df6ecbbcbf..157e042097e687 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -1993,6 +1993,17 @@
     func : lu_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : lu_solve_grad
+  forward : lu_solve (Tensor x, Tensor lu, Tensor pivots, str trans = "N") -> Tensor(out)
+  args : (Tensor x, Tensor lu, Tensor pivots, Tensor out, Tensor out_grad, str trans = "N")
+  output : Tensor(x_grad), Tensor(lu_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, lu]
+  kernel :
+    func : lu_solve_grad
+    data_type : x
+
 - backward_op : lu_unpack_grad
   forward : lu_unpack (Tensor x, Tensor y, bool unpack_ludata = true, bool unpack_pivots = true) -> Tensor(pmat), Tensor(l), Tensor(u)
   args : (Tensor x, Tensor y, Tensor l, Tensor u, Tensor pmat, Tensor l_grad, Tensor u_grad, bool unpack_ludata, bool unpack_pivots)
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index d77daee9429cd1..326b7353c273ae 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3148,6 +3148,17 @@
   backward : lu_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : lu_solve
+  args : (Tensor x, Tensor lu, Tensor pivots, str trans = "N")
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : lu_solve
+    data_type : x
+  backward : lu_solve_grad
+
 - op : lu_unpack
   args : (Tensor x, Tensor y, bool unpack_ludata = true, bool unpack_pivots = true)
   output : Tensor(pmat), Tensor(l), Tensor(u)
diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py
index 06e978572a26ee..a37d4d08f54ae6 100644
--- a/python/paddle/linalg.py
+++ b/python/paddle/linalg.py
@@ -31,6 +31,7 @@
     householder_product,
     lstsq,
     lu,
+    lu_solve,
     lu_unpack,
     matrix_exp,
     matrix_norm,
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index ff8eed1aa17159..bb2df0d16ab00b 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -87,6 +87,7 @@
     householder_product,
     lstsq,
     lu,
+    lu_solve,
     lu_unpack,
     matmul,
     matrix_power,
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index de3296efb7d6ff..5547a7cdfbe32f 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -3575,6 +3575,59 @@ def lu(
     return lu, p
 
 
+
+def lu_solve(b: Tensor, lu_data: Tensor, pivots: Tensor, name=None) -> Tensor:
+    r"""
+    Computes the solution y to the system of linear equations :math:`Ay = b`,
+    given the LU decomposition of :math:`A` and the column vector :math:`b`.
+    Args:
+        b (Tensor): Column vector `b` in the above equation. It has shape :math:`(*, m, k)`,
+            where :math:`*` is batch dimensions, with data type float32, float64.
+        lu_data (Tensor): LU decomposition. It has shape :math:`(*, m, m)`, where :math:`*` is batch
+            dimensions, that can be decomposed into an upper triangular matrix U and a lower triangular
+            matrix L, with data type float32, float64.
+        pivots (Tensor): Permutation matrix P of LU decomposition. It has
+            shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted
+            to a permutation matrix P, with data type int32.
+    Returns:
+        Tensor, with the same data type as `b` and `lu_data`.
+    Raises:
+        TypeError: If dtype of `b` or `lu_data` is not one of: float32, float64.
+        TypeError: If dtype of `pivots` is not: int32.
+        TypeError: If dtype of `b` is not the same as dtype of `lu_data`.
+        ValueError: If the batch dimensions of `pivots` do not match the batch dimensions of `lu_data`.
+        ValueError: If `b` has fewer than 2 dimensions, `lu_data` has fewer than 2 dimensions, or `pivots` has fewer than 1 dimension.
+    Supported Platforms:
+        ``GPU`` ``CPU``
+    Examples:
+        >>> import paddle
+        >>> import numpy as np
+        >>> b = paddle.to_tensor(np.array([[1], [3], [3]]), paddle.float32)
+        >>> LU_data = paddle.to_tensor(np.array([[2, 1, 1], [0.5, 1, 1.5], [0.5, 0, 2.5]]), paddle.float32)
+        >>> LU_pivots = paddle.to_tensor(np.array([2, 2, 3]), paddle.int32)
+        >>> y = paddle.linalg.lu_solve(b, LU_data, LU_pivots)
+        >>> print(y)
+        [[ 1.9000002]
+         [-1.4000001]
+         [ 0.6      ]]
+    """
+    b = b if b.shape[:-2] == lu_data.shape[:-2] else paddle.broadcast_to(b, lu_data.shape[:-2] + b.shape[-2:])
+    if in_dynamic_or_pir_mode():
+        out = _C_ops.lu_solve(b, lu_data, pivots, 'N')
+    else:
+        check_variable_and_dtype(b, 'b', ['float32', 'float64'], 'lu_solve')
+        check_variable_and_dtype(lu_data, 'lu_data', ['float32', 'float64'], 'lu_solve')
+        check_variable_and_dtype(pivots, 'pivots', ['int32'], 'lu_solve')
+        helper = LayerHelper('lu_solve', **locals())
+        out = helper.create_variable_for_type_inference(dtype=b.dtype)
+        helper.append_op(
+            type='lu_solve',
+            inputs={'X': b, 'Lu': lu_data, 'Pivots': pivots},
+            attrs={'trans': 'N'},
+            outputs={'Out': out}
+        )
+    return out
+
 
 def lu_unpack(
     x: Tensor,
     y: Tensor,
diff --git a/test/legacy_test/test_lu_solve_op.py b/test/legacy_test/test_lu_solve_op.py
new file mode 100644
index 00000000000000..6ee063c7b79b81
--- /dev/null
+++ b/test/legacy_test/test_lu_solve_op.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import scipy.linalg
+from op_test import OpTest
+
+import paddle
+from paddle import base
+from paddle.base import core
+
+class TestLuSolveAPI(unittest.TestCase):
+    def test_dygraph(self):
+        paddle.disable_static(base.CPUPlace())
+        A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
+        b = np.array([[1], [1], [1], [1]])  # column vector, shape (4, 1)
+        lu, piv = scipy.linalg.lu_factor(A)
+        x = scipy.linalg.lu_solve((lu, piv), b)
+
+        lu_pd = paddle.to_tensor(lu, dtype='float32')
+        piv_pd = paddle.to_tensor(piv + 1, dtype='int32')  # LAPACK pivots are 1-based
+        b_pd = paddle.to_tensor(b, dtype='float32')
+        x_pd = paddle.linalg.lu_solve(b_pd, lu_pd, piv_pd)
+        x_np = x_pd.numpy()
+        assert np.allclose(x, x_np)
+
+if __name__ == "__main__":
+    unittest.main()
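A quick usage sketch (illustrative only, not part of the diff above). It assumes the new op is exposed as paddle.linalg.lu_solve, as the imports in this PR suggest, and that the pivots follow the 1-based LAPACK convention produced by paddle.linalg.lu.

import numpy as np
import paddle

A = paddle.to_tensor(np.random.rand(3, 3), dtype='float32')
b = paddle.to_tensor(np.random.rand(3, 2), dtype='float32')

lu, pivots = paddle.linalg.lu(A)           # packed LU factors and 1-based pivots
x = paddle.linalg.lu_solve(b, lu, pivots)  # solve A x = b from the factorization

# Residual check: A @ x should reproduce b up to float32 round-off.
print(float((A @ x - b).abs().max()))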