diff --git a/ChangeLog b/ChangeLog index 7137a57d..0e3e1c90 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,10 +1,14 @@ -2024-11-07 Dirk Eddelbuettel +2024-11-14 Dirk Eddelbuettel - * DESCRIPTION (Version, Date): RcppArmadillo 12.1.99-1 + * DESCRIPTION (Version, Date): RcppArmadillo 14.1.99-2 * inst/NEWS.Rd: Idem * configure.ac: Idem * configure: Idem + * inst/include/armadillo_bits/: Re-sync Armadillo 14.2.0-rc1 + +2024-11-07 Dirk Eddelbuettel + * DESCRIPTION (Depends): Increase to Rcpp (>= 1.0.12) as it supplies the required printf format change more current R versions need diff --git a/DESCRIPTION b/DESCRIPTION index dfe1e281..563daff7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: RcppArmadillo Type: Package Title: 'Rcpp' Integration for the 'Armadillo' Templated Linear Algebra Library -Version: 14.1.99-1 -Date: 2024-11-07 +Version: 14.1.99-2 +Date: 2024-11-14 Authors@R: c(person("Dirk", "Eddelbuettel", role = c("aut", "cre"), email = "edd@debian.org", comment = c(ORCID = "0000-0001-6419-907X")), person("Romain", "Francois", role = "aut", diff --git a/inst/NEWS.Rd b/inst/NEWS.Rd index d4b6e5f3..f6dd5c8b 100644 --- a/inst/NEWS.Rd +++ b/inst/NEWS.Rd @@ -3,7 +3,7 @@ \newcommand{\ghpr}{\href{https://github.com/RcppCore/RcppArmadillo/pull/#1}{##1}} \newcommand{\ghit}{\href{https://github.com/RcppCore/RcppArmadillo/issues/#1}{##1}} -\section{Changes in RcppArmadillo version 14.1.99-1 (2024-11-07)}{ +\section{Changes in RcppArmadillo version 14.1.99-2 (2024-11-14)}{ \itemize{ \item Upgraded to Armadillo release 14.2.0-rc0 (Stochastic Parrot) \itemize{ @@ -11,6 +11,9 @@ \code{rcond()} \item Faster handling of hermitian matrices by \code{inv()}, \code{rcond()}, \code{cond()}, \code{pinv()}, \code{rank()} + \item Added \code{solve_opts::force_sym} option to \code{solve()} to + force the use of the symmetric solver + \item More efficient handling of compound expressions by \code{solve()} } \item Added exporter specialisation for \code{icube} for the 
\code{ARMA_64BIT_WORD} case diff --git a/inst/include/armadillo_bits/CubeToMatOp_bones.hpp b/inst/include/armadillo_bits/CubeToMatOp_bones.hpp index cd2ba599..a53dd2d0 100644 --- a/inst/include/armadillo_bits/CubeToMatOp_bones.hpp +++ b/inst/include/armadillo_bits/CubeToMatOp_bones.hpp @@ -36,6 +36,9 @@ class CubeToMatOp : public Base< typename T1::elem_type, CubeToMatOp + constexpr bool is_alias(const Mat&) const { return false; } + static constexpr bool is_row = op_type::template traits::is_row; static constexpr bool is_col = op_type::template traits::is_col; static constexpr bool is_xvec = op_type::template traits::is_xvec; diff --git a/inst/include/armadillo_bits/Gen_bones.hpp b/inst/include/armadillo_bits/Gen_bones.hpp index 172e5b9c..352bfcdf 100644 --- a/inst/include/armadillo_bits/Gen_bones.hpp +++ b/inst/include/armadillo_bits/Gen_bones.hpp @@ -54,6 +54,9 @@ class Gen inline void apply_inplace_div (Mat& out) const; inline void apply(subview& out) const; + + template + constexpr bool is_alias(const Mat&) const { return false; } }; diff --git a/inst/include/armadillo_bits/Glue_bones.hpp b/inst/include/armadillo_bits/Glue_bones.hpp index 197ae746..0b4c73f6 100644 --- a/inst/include/armadillo_bits/Glue_bones.hpp +++ b/inst/include/armadillo_bits/Glue_bones.hpp @@ -56,6 +56,9 @@ class Glue inline Glue(const T1& in_A, const T2& in_B, const uword in_aux_uword); inline ~Glue(); + template + inline bool is_alias(const Mat& X) const; + const T1& A; //!< first operand; must be derived from Base const T2& B; //!< second operand; must be derived from Base uword aux_uword; //!< storage of auxiliary data, uword format diff --git a/inst/include/armadillo_bits/Glue_meat.hpp b/inst/include/armadillo_bits/Glue_meat.hpp index cf4cfc68..66834a17 100644 --- a/inst/include/armadillo_bits/Glue_meat.hpp +++ b/inst/include/armadillo_bits/Glue_meat.hpp @@ -53,4 +53,17 @@ Glue::~Glue() +template +template +inline +bool +Glue::is_alias(const Mat& X) const + { + arma_debug_sigprint(); 
+ + return (A.is_alias(X) || B.is_alias(X)); + } + + + //! @} diff --git a/inst/include/armadillo_bits/Mat_bones.hpp b/inst/include/armadillo_bits/Mat_bones.hpp index 957caaa6..69d92cf8 100644 --- a/inst/include/armadillo_bits/Mat_bones.hpp +++ b/inst/include/armadillo_bits/Mat_bones.hpp @@ -771,6 +771,9 @@ class Mat : public Base< eT, Mat > inline void steal_mem_col(Mat& X, const uword max_n_rows); + template + arma_inline bool is_alias(const Mat& X) const; //!< don't use this unless you're writing code internal to Armadillo + template class fixed; diff --git a/inst/include/armadillo_bits/Mat_meat.hpp b/inst/include/armadillo_bits/Mat_meat.hpp index 0f785d5c..51524abd 100644 --- a/inst/include/armadillo_bits/Mat_meat.hpp +++ b/inst/include/armadillo_bits/Mat_meat.hpp @@ -1324,6 +1324,19 @@ Mat::steal_mem_col(Mat& x, const uword max_n_rows) +template +template +arma_inline +bool +Mat::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (is_same_type::yes) && (void_ptr(this) == void_ptr(&X)); + } + + + //! construct a matrix from a given auxiliary array of eTs. //! if copy_aux_mem is true, new memory is allocated and the array is copied. //! if copy_aux_mem is false, the auxiliary array is used directly (without allocating memory and copying). 
diff --git a/inst/include/armadillo_bits/Op_bones.hpp b/inst/include/armadillo_bits/Op_bones.hpp index fa8c3efd..7fc4088f 100644 --- a/inst/include/armadillo_bits/Op_bones.hpp +++ b/inst/include/armadillo_bits/Op_bones.hpp @@ -58,6 +58,9 @@ class Op inline Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline ~Op(); + template + inline bool is_alias(const Mat& X) const; + arma_aligned const T1& m; //!< the operand; must be derived from Base arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 arma_aligned uword aux_uword_a; //!< auxiliary data, uword format diff --git a/inst/include/armadillo_bits/Op_meat.hpp b/inst/include/armadillo_bits/Op_meat.hpp index 66fbaba6..879f5682 100644 --- a/inst/include/armadillo_bits/Op_meat.hpp +++ b/inst/include/armadillo_bits/Op_meat.hpp @@ -76,4 +76,17 @@ Op::~Op() +template +template +inline +bool +Op::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + //! @} diff --git a/inst/include/armadillo_bits/Proxy.hpp b/inst/include/armadillo_bits/Proxy.hpp index a51580dd..441f68ab 100644 --- a/inst/include/armadillo_bits/Proxy.hpp +++ b/inst/include/armadillo_bits/Proxy.hpp @@ -188,7 +188,7 @@ struct Proxy< Mat > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&Q) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } @@ -235,7 +235,7 @@ struct Proxy< Col > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? 
(void_ptr(&Q) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&Q) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } @@ -282,7 +282,7 @@ struct Proxy< Row > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&Q) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&Q) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } @@ -1013,7 +1013,7 @@ struct Proxy< subview > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } @@ -1060,7 +1060,7 @@ struct Proxy< subview_col > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } @@ -1109,7 +1109,7 @@ struct Proxy< subview_cols > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? 
(void_ptr(&(sv.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(sv.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return sv.check_overlap(X); } @@ -1156,7 +1156,7 @@ struct Proxy< subview_row > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } @@ -1304,7 +1304,7 @@ struct Proxy< diagview > arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(Q.m)) == void_ptr(&X)) : false; } + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::yes) && (void_ptr(&(Q.m)) == void_ptr(&X)); } template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } diff --git a/inst/include/armadillo_bits/SpToDGlue_bones.hpp b/inst/include/armadillo_bits/SpToDGlue_bones.hpp index 158dd6b8..e21b3127 100644 --- a/inst/include/armadillo_bits/SpToDGlue_bones.hpp +++ b/inst/include/armadillo_bits/SpToDGlue_bones.hpp @@ -36,6 +36,9 @@ class SpToDGlue : public Base< typename T1::elem_type, SpToDGlue + constexpr bool is_alias(const Mat&) const { return false; } + const T1& A; //!< first operand; must be derived from Base or SpBase const T2& B; //!< second operand; must be derived from Base or SpBase }; diff --git a/inst/include/armadillo_bits/SpToDOp_bones.hpp b/inst/include/armadillo_bits/SpToDOp_bones.hpp index 44fa7895..2215c97a 100644 --- a/inst/include/armadillo_bits/SpToDOp_bones.hpp +++ b/inst/include/armadillo_bits/SpToDOp_bones.hpp @@ -39,6 +39,9 @@ class SpToDOp : public 
Base< typename T1::elem_type, SpToDOp > inline SpToDOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline ~SpToDOp(); + template + constexpr bool is_alias(const Mat&) const { return false; } + arma_aligned const T1& m; //!< the operand; must be derived from SpBase arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 arma_aligned uword aux_uword_a; //!< auxiliary data, uword format diff --git a/inst/include/armadillo_bits/arma_forward.hpp b/inst/include/armadillo_bits/arma_forward.hpp index b35d64ef..b56d4e4d 100644 --- a/inst/include/armadillo_bits/arma_forward.hpp +++ b/inst/include/armadillo_bits/arma_forward.hpp @@ -91,6 +91,8 @@ class op_diagmat; class op_trimat; class op_vectorise_row; class op_vectorise_col; +class op_symmatu; +class op_symmatl; class op_row_as_mat; class op_col_as_mat; diff --git a/inst/include/armadillo_bits/arma_version.hpp b/inst/include/armadillo_bits/arma_version.hpp index 7a32dfbd..bf3b8211 100644 --- a/inst/include/armadillo_bits/arma_version.hpp +++ b/inst/include/armadillo_bits/arma_version.hpp @@ -23,8 +23,8 @@ #define ARMA_VERSION_MAJOR 14 #define ARMA_VERSION_MINOR 1 -#define ARMA_VERSION_PATCH 90 -#define ARMA_VERSION_NAME "unstable" +#define ARMA_VERSION_PATCH 91 +#define ARMA_VERSION_NAME "14.2-RC1" diff --git a/inst/include/armadillo_bits/auxlib_bones.hpp b/inst/include/armadillo_bits/auxlib_bones.hpp index 2fbd9d52..f68ba1f5 100644 --- a/inst/include/armadillo_bits/auxlib_bones.hpp +++ b/inst/include/armadillo_bits/auxlib_bones.hpp @@ -281,6 +281,20 @@ class auxlib // + template + inline static bool solve_sym_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_sym_fast(Mat< std::complex >& out, Mat< std::complex >& A, const Base< std::complex, T1 >& B_expr); + + template + inline static bool solve_sym_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); + + template + inline static bool solve_sym_rcond(Mat< 
std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr); + + // + template inline static bool solve_sympd_fast(Mat& out, Mat& A, const Base& B_expr); diff --git a/inst/include/armadillo_bits/auxlib_meat.hpp b/inst/include/armadillo_bits/auxlib_meat.hpp index a6a1f122..7546b641 100644 --- a/inst/include/armadillo_bits/auxlib_meat.hpp +++ b/inst/include/armadillo_bits/auxlib_meat.hpp @@ -322,8 +322,9 @@ auxlib::inv_sym(Mat< std::complex >& A) #if defined(ARMA_CRIPPLED_LAPACK) { - arma_ignore(A); - return false; + arma_debug_print("auxlib::inv_sym(): redirecting to auxlib::inv() due to crippled LAPACK"); + + return auxlib::inv(A); } #elif defined(ARMA_USE_LAPACK) { @@ -408,17 +409,20 @@ auxlib::inv_sym_rcond(Mat& A, eT& out_rcond) podarray ipiv(A.n_rows); podarray iwork(A.n_rows); - eT work_query[2] = {}; - blas_int lwork_query = -1; - - arma_debug_print("lapack::sytrf()"); - lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); - - if(info != 0) { return false; } - - blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); - - lwork = (std::max)(lwork_proposed, lwork); + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } podarray work( static_cast(lwork) ); @@ -475,9 +479,9 @@ auxlib::inv_sym_rcond(Mat< std::complex >& A, T& out_rcond) #if defined(ARMA_CRIPPLED_LAPACK) { - arma_ignore(A); - arma_ignore(out_rcond); - return false; + arma_debug_print("auxlib::inv_sym_rcond(): redirecting to auxlib::inv_rcond() due to crippled LAPACK"); + + return auxlib::inv_rcond(A, out_rcond); } #elif 
defined(ARMA_USE_LAPACK) { @@ -497,17 +501,20 @@ auxlib::inv_sym_rcond(Mat< std::complex >& A, T& out_rcond) podarray ipiv(A.n_rows); podarray lanhe_work(A.n_rows); - eT work_query[2] = {}; - blas_int lwork_query = -1; - - arma_debug_print("lapack::hetrf()"); - lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); - - if(info != 0) { return false; } - - blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); - - lwork = (std::max)(lwork_proposed, lwork); + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } podarray work( static_cast(lwork) ); @@ -4558,6 +4565,328 @@ auxlib::solve_square_refine(Mat< std::complex >& out, typ +template +inline +bool +auxlib::solve_sym_fast(Mat& out, Mat& A, const Base& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + arma_conform_assert_blas_size(A,out); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + if(n > blas_int(podarray_prealloc_n_elem::val)) + { + eT work_query[2] = {}; + 
blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sytrs()"); + lapack::sytrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sym_fast(Mat< std::complex >& out, Mat< std::complex >& A, const Base< std::complex, T1 >& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_debug_print("auxlib::solve_sym_fast(): redirecting to auxlib::solve_square_fast() due to crippled LAPACK"); + + return auxlib::solve_square_fast(out, A, B_expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef std::complex eT; + + arma_conform_assert_blas_size(A,out); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + + podarray ipiv(A.n_rows); + + if(n > 
blas_int(podarray_prealloc_n_elem::val)) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hetrs()"); + lapack::hetrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sym_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_debug_print("auxlib::solve_sym_rcond(): redirecting to auxlib::solve_square_rcond() due to crippled LAPACK"); + + return auxlib::solve_square_rcond(out, out_rcond, A, B_expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type eT; + + out_rcond = eT(0); + + arma_conform_assert_blas_size(A,out); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = 
(std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::sycon() requirements + blas_int info = 0; + eT norm_val = eT(0); + eT tmp_rcond = eT(0); + + podarray ipiv(A.n_rows); + podarray iwork(A.n_rows); + + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sytrs()"); + lapack::sytrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::sycon()"); + lapack::sycon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &tmp_rcond, work.memptr(), iwork.memptr(), &info); + + out_rcond = tmp_rcond; + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::solve_sym_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr) + { + arma_debug_sigprint(); + + out = B_expr.get_ref(); + + const uword B_n_rows = out.n_rows; + const uword B_n_cols = out.n_cols; + + arma_conform_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + 
if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + out_rcond = T(0); + + arma_conform_assert_blas_size(A,out); + + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(out.n_rows); + blas_int nrhs = blas_int(out.n_cols); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), 2*n); // 2*n due to lapack::hecon() requirements + blas_int info = 0; + T norm_val = T(0); + T tmp_rcond = T(0); + + podarray ipiv(A.n_rows); + podarray lanhe_work(A.n_rows); + + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } + + podarray work( static_cast(lwork) ); + + arma_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? 
auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, lanhe_work.memptr()); + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hetrs()"); + lapack::hetrs(&uplo, &n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); + + if(info != 0) { return false; } + + arma_debug_print("lapack::hecon()"); + lapack::hecon(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &norm_val, &tmp_rcond, work.memptr(), &info); + + out_rcond = tmp_rcond; + + return (info == 0); + } + #else + { + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + template inline bool @@ -6503,17 +6832,20 @@ auxlib::rcond_sym(Mat& A) podarray ipiv(A.n_rows); podarray iwork(A.n_rows); - eT work_query[2] = {}; - blas_int lwork_query = -1; - - arma_debug_print("lapack::sytrf()"); - lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); - - if(info != 0) { return eT(0); } - - blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); - - lwork = (std::max)(lwork_proposed, lwork); + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::sytrf()"); + lapack::sytrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return eT(0); } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } podarray work( static_cast(lwork) ); @@ -6574,17 +6906,20 @@ auxlib::rcond_sym(Mat< std::complex >& A) podarray ipiv(A.n_rows); podarray lanhe_work(A.n_rows); - eT work_query[2] = {}; - blas_int lwork_query = -1; - - arma_debug_print("lapack::hetrf()"); - lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, 
&info); - - if(info != 0) { return T(0); } - - blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); - - lwork = (std::max)(lwork_proposed, lwork); + if( (2*n) > blas_int(podarray_prealloc_n_elem::val) ) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_debug_print("lapack::hetrf()"); + lapack::hetrf(&uplo, &n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return T(0); } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + + lwork = (std::max)(lwork_proposed, lwork); + } podarray work( static_cast(lwork) ); diff --git a/inst/include/armadillo_bits/def_lapack.hpp b/inst/include/armadillo_bits/def_lapack.hpp index 95da42fb..ad4e50c2 100644 --- a/inst/include/armadillo_bits/def_lapack.hpp +++ b/inst/include/armadillo_bits/def_lapack.hpp @@ -275,6 +275,12 @@ #define arma_chetrf chetrf #define arma_zhetrf zhetrf + #define arma_ssytrs ssytrs + #define arma_dsytrs dsytrs + + #define arma_chetrs chetrs + #define arma_zhetrs zhetrs + #define arma_ssytri ssytri #define arma_dsytri dsytri @@ -529,6 +535,12 @@ #define arma_chetrf CHETRF #define arma_zhetrf ZHETRF + #define arma_ssytrs SSYTRS + #define arma_dsytrs DSYTRS + + #define arma_chetrs CHETRS + #define arma_zhetrs ZHETRS + #define arma_ssytri SSYTRI #define arma_dsytri DSYTRI @@ -890,6 +902,14 @@ extern "C" void arma_fortran(arma_chetrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; void arma_fortran(arma_zhetrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + // solve system using pre-computed factorisation (real) + void arma_fortran(arma_ssytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* 
ipiv, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + + // solve system using pre-computed factorisation (complex) + void arma_fortran(arma_chetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zhetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + // inverse of symmetric matrix using pre-computed factorisation (real) void arma_fortran(arma_ssytri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; void arma_fortran(arma_dsytri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; @@ -1238,6 +1258,14 @@ extern "C" void arma_fortran(arma_chetrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; void arma_fortran(arma_zhetrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + // solve system using pre-computed factorisation (real) + void arma_fortran(arma_ssytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void 
arma_fortran(arma_dsytrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + + // solve system using pre-computed factorisation (complex) + void arma_fortran(arma_zhetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_chetrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + // inverse of symmetric matrix using pre-computed factorisation (real) void arma_fortran(arma_ssytri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, blas_int* info) ARMA_NOEXCEPT; void arma_fortran(arma_dsytri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, blas_int* info) ARMA_NOEXCEPT; diff --git a/inst/include/armadillo_bits/diagview_bones.hpp b/inst/include/armadillo_bits/diagview_bones.hpp index 5aa4bcee..4117aa45 100644 --- a/inst/include/armadillo_bits/diagview_bones.hpp +++ b/inst/include/armadillo_bits/diagview_bones.hpp @@ -108,6 +108,9 @@ class diagview : public Base< eT, diagview > inline static void schur_inplace(Mat& out, const diagview& in); inline static void div_inplace(Mat& out, const diagview& in); + template + inline bool is_alias(const Mat& X) const; + friend class Mat; friend class subview; diff --git a/inst/include/armadillo_bits/diagview_meat.hpp b/inst/include/armadillo_bits/diagview_meat.hpp index c8d741ab..a33b2ecb 100644 --- a/inst/include/armadillo_bits/diagview_meat.hpp +++ b/inst/include/armadillo_bits/diagview_meat.hpp @@ -236,13 +236,13 @@ diagview::operator= (const Base& o) "diagview: given object has incompatible size" ); 
- const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -309,13 +309,13 @@ diagview::operator+=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -382,13 +382,13 @@ diagview::operator-=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -455,13 +455,13 @@ diagview::operator%=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); 
} + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -528,13 +528,13 @@ diagview::operator/=(const Base& o) "diagview: given object has incompatible size" ); - const bool is_alias = P.is_alias(d_m); + const bool have_alias = P.is_alias(d_m); - if(is_alias) { arma_debug_print("aliasing detected"); } + if(have_alias) { arma_debug_print("aliasing detected"); } - if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (is_alias) ) + if( (is_Mat::stored_type>::value) || (Proxy::use_at) || (have_alias) ) { - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& x = tmp.M; const eT* x_mem = x.memptr(); @@ -1022,4 +1022,17 @@ diagview::randn() +template +template +inline +bool +diagview::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + //! 
@} diff --git a/inst/include/armadillo_bits/eGlue_bones.hpp b/inst/include/armadillo_bits/eGlue_bones.hpp index 097dc6cb..a86377a9 100644 --- a/inst/include/armadillo_bits/eGlue_bones.hpp +++ b/inst/include/armadillo_bits/eGlue_bones.hpp @@ -44,6 +44,9 @@ class eGlue : public Base< typename T1::elem_type, eGlue > arma_inline ~eGlue(); arma_inline eGlue(const T1& in_A, const T2& in_B); + template + inline bool is_alias(const Mat& X) const; + arma_inline uword get_n_rows() const; arma_inline uword get_n_cols() const; arma_inline uword get_n_elem() const; diff --git a/inst/include/armadillo_bits/eGlue_meat.hpp b/inst/include/armadillo_bits/eGlue_meat.hpp index 04eb6ba2..4d55bc78 100644 --- a/inst/include/armadillo_bits/eGlue_meat.hpp +++ b/inst/include/armadillo_bits/eGlue_meat.hpp @@ -49,6 +49,17 @@ eGlue::eGlue(const T1& in_A, const T2& in_B) +template +template +inline +bool +eGlue::is_alias(const Mat& X) const + { + return (P1.is_alias(X) || P2.is_alias(X)); + } + + + template arma_inline uword diff --git a/inst/include/armadillo_bits/eOp_bones.hpp b/inst/include/armadillo_bits/eOp_bones.hpp index d32abddb..a200a6bb 100644 --- a/inst/include/armadillo_bits/eOp_bones.hpp +++ b/inst/include/armadillo_bits/eOp_bones.hpp @@ -50,6 +50,9 @@ class eOp : public Base< typename T1::elem_type, eOp > inline eOp(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline eOp(const T1& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b); + template + inline bool is_alias(const Mat& X) const; + arma_inline uword get_n_rows() const; arma_inline uword get_n_cols() const; arma_inline uword get_n_elem() const; diff --git a/inst/include/armadillo_bits/eOp_meat.hpp b/inst/include/armadillo_bits/eOp_meat.hpp index 75dfec02..57473fe7 100644 --- a/inst/include/armadillo_bits/eOp_meat.hpp +++ b/inst/include/armadillo_bits/eOp_meat.hpp @@ -74,7 +74,20 @@ eOp::~eOp() arma_debug_sigprint(); } + + +template +template +inline +bool 
+eOp::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + return P.is_alias(X); + } + + template arma_inline diff --git a/inst/include/armadillo_bits/glue_solve_bones.hpp b/inst/include/armadillo_bits/glue_solve_bones.hpp index 20c01659..c04b17a6 100644 --- a/inst/include/armadillo_bits/glue_solve_bones.hpp +++ b/inst/include/armadillo_bits/glue_solve_bones.hpp @@ -140,6 +140,7 @@ namespace solve_opts static constexpr uword flag_refine = uword(1u << 9); static constexpr uword flag_no_trimat = uword(1u << 10); static constexpr uword flag_force_approx = uword(1u << 11); + static constexpr uword flag_force_sym = uword(1u << 12); struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; struct opts_fast : public opts { inline constexpr opts_fast() : opts(flag_fast ) {} }; @@ -154,6 +155,7 @@ namespace solve_opts struct opts_refine : public opts { inline constexpr opts_refine() : opts(flag_refine ) {} }; struct opts_no_trimat : public opts { inline constexpr opts_no_trimat() : opts(flag_no_trimat ) {} }; struct opts_force_approx : public opts { inline constexpr opts_force_approx() : opts(flag_force_approx) {} }; + struct opts_force_sym : public opts { inline constexpr opts_force_sym() : opts(flag_force_sym ) {} }; static constexpr opts_none none; static constexpr opts_fast fast; @@ -168,6 +170,7 @@ namespace solve_opts static constexpr opts_refine refine; static constexpr opts_no_trimat no_trimat; static constexpr opts_force_approx force_approx; + static constexpr opts_force_sym force_sym; } diff --git a/inst/include/armadillo_bits/glue_solve_meat.hpp b/inst/include/armadillo_bits/glue_solve_meat.hpp index f50bcc40..6aeb6729 100644 --- a/inst/include/armadillo_bits/glue_solve_meat.hpp +++ b/inst/include/armadillo_bits/glue_solve_meat.hpp @@ -99,6 +99,7 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const const bool refine = has_user_flags && bool(flags & solve_opts::flag_refine ); const bool no_trimat = 
has_user_flags && bool(flags & solve_opts::flag_no_trimat ); const bool force_approx = has_user_flags && bool(flags & solve_opts::flag_force_approx); + const bool force_sym = has_user_flags && bool(flags & solve_opts::flag_force_sym ); if(has_user_flags) { @@ -114,10 +115,11 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const if(refine ) { arma_debug_print("refine"); } if(no_trimat ) { arma_debug_print("no_trimat"); } if(force_approx) { arma_debug_print("force_approx"); } + if(force_sym ) { arma_debug_print("force_sym"); } - arma_conform_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); - arma_conform_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); - arma_conform_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); + arma_conform_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); + arma_conform_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); + arma_conform_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); } Mat A = A_expr.get_ref(); @@ -128,26 +130,34 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const arma_conform_check( no_approx, "solve(): options 'no_approx' and 'force_approx' are mutually exclusive" ); - if(fast) { arma_warn(2, "solve(): option 'fast' ignored for forced approximate solution" ); } - if(equilibrate) { arma_warn(2, "solve(): option 'equilibrate' ignored for forced approximate solution" ); } - if(refine) { arma_warn(2, "solve(): option 'refine' ignored for forced approximate solution" ); } - if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for forced approximate solution" ); } + if(fast) { arma_warn(2, "solve(): option 'fast' ignored for forced approximate solution" ); } + 
if(equilibrate) { arma_warn(2, "solve(): option 'equilibrate' ignored for forced approximate solution" ); } + if(refine) { arma_warn(2, "solve(): option 'refine' ignored for forced approximate solution" ); } + if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for forced approximate solution" ); } + if(force_sym) { arma_warn(2, "solve(): option 'force_sym' ignored for forced approximate solution" ); } return auxlib::solve_approx_svd(actual_out, A, B_expr.get_ref()); // A is overwritten } + if(force_sym) + { + if((arma_config::check_conform) && (auxlib::rudimentary_sym_check(A) == false)) + { + if(is_cx::no ) { arma_warn(1, "solve(): option 'force_sym' enabled, but given matrix is not symmetric"); } + if(is_cx::yes) { arma_warn(1, "solve(): option 'force_sym' enabled, but given matrix is not hermitian"); } + } + + if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for forced symmetric solver" ); } + if(equilibrate) { arma_warn(2, "solve(): option 'force_sym' ignored as option 'equilibrate' is enabled (combination not implemented yet)" ); } + if(refine) { arma_warn(2, "solve(): option 'force_sym' ignored as option 'refine' is enabled (combination not implemented yet)" ); } + } + // A_expr and B_expr can be used more than once (sympd optimisation fails or approximate solution required), // so ensure they are not overwritten in case we have aliasing - bool is_alias = true; // assume we have aliasing until we can prove otherwise + const bool is_alias = A_expr.get_ref().is_alias(actual_out) || B_expr.get_ref().is_alias(actual_out); - if(is_Mat::value && is_Mat::value) - { - const quasi_unwrap UA( A_expr.get_ref() ); - const quasi_unwrap UB( B_expr.get_ref() ); - - is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); - } + if(is_alias) { arma_debug_print("glue_solve_gen_full::apply(): aliasing detected"); } Mat tmp; Mat& out = (is_alias) ? 
tmp : actual_out; @@ -162,12 +172,13 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const uword KL = 0; uword KU = 0; - const bool is_band = arma_config::optimise_band && ((no_band || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32))); + const bool is_band = arma_config::optimise_band && ( (no_band || force_sym || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32)) ); - const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || is_band ) ? false : trimat_helper::is_triu(A); - const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || is_band || is_triu) ? false : trimat_helper::is_tril(A); + const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || force_sym || is_band ) ? false : trimat_helper::is_triu(A); + const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || force_sym || is_band || is_triu) ? false : trimat_helper::is_tril(A); - const bool try_sympd = arma_config::optimise_sym && ((no_sympd || auxlib::crippled_lapack(A) || is_band || is_triu || is_tril) ? false : (likely_sympd ? true : sym_helper::guess_sympd(A, uword(16)))); + const bool is_sym = arma_config::optimise_sym && ( (refine || equilibrate || likely_sympd || force_sym || is_band || is_triu || is_tril || auxlib::crippled_lapack(A)) ? false : is_sym_expr::eval(A_expr.get_ref()) ); + const bool try_sympd = arma_config::optimise_sym && ( ( no_sympd || is_sym || force_sym || is_band || is_triu || is_tril || auxlib::crippled_lapack(A)) ? false : (likely_sympd ? 
true : sym_helper::guess_sympd(A, uword(16))) ); if(fast) { @@ -201,6 +212,13 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const status = auxlib::solve_trimat_fast(out, A, B_expr.get_ref(), layout); } else + if(force_sym || is_sym) + { + arma_debug_print("glue_solve_gen_full::apply(): fast + sym"); + + status = auxlib::solve_sym_fast(out, A, B_expr.get_ref()); // A is overwritten + } + else if(try_sympd) { arma_debug_print("glue_solve_gen_full::apply(): fast + try_sympd"); @@ -238,6 +256,10 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const status = auxlib::solve_band_refine(out, rcond, A, KL, KU, B_expr, equilibrate); } + // else + // if(force_sym || is_sym) // TODO: implement auxlib::solve_sym_refine() + // { + // } else if(try_sympd) { @@ -287,6 +309,13 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); } else + if(force_sym || is_sym) + { + arma_debug_print("glue_solve_gen_full::apply(): rcond + sym"); + + status = auxlib::solve_sym_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten + } + else if(try_sympd) { bool sympd_state = false; @@ -315,6 +344,7 @@ glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const if(equilibrate) { arma_warn(2, "solve(): option 'equilibrate' ignored for non-square matrix" ); } if(refine) { arma_warn(2, "solve(): option 'refine' ignored for non-square matrix" ); } if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for non-square matrix" ); } + if(force_sym) { arma_warn(2, "solve(): option 'force_sym' ignored for non-square matrix" ); } if(fast) { @@ -406,14 +436,9 @@ glue_solve_tri_default::apply(Mat& actual_out, const Base& A_expr, co const uword layout = (triu) ? 
uword(0) : uword(1); - bool is_alias = true; + const bool is_alias = A_expr.get_ref().is_alias(actual_out) || B_expr.get_ref().is_alias(actual_out); - if(is_Mat::value) - { - const quasi_unwrap UB(B_expr.get_ref()); - - is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); - } + if(is_alias) { arma_debug_print("glue_solve_tri_default::apply(): aliasing detected"); } T rcond = T(0); bool status = false; @@ -497,6 +522,7 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const const bool refine = bool(flags & solve_opts::flag_refine ); const bool no_trimat = bool(flags & solve_opts::flag_no_trimat ); const bool force_approx = bool(flags & solve_opts::flag_force_approx); + const bool force_sym = bool(flags & solve_opts::flag_force_sym ); arma_debug_print("glue_solve_tri_full::apply(): enabled flags:"); @@ -510,6 +536,10 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const if(refine ) { arma_debug_print("refine"); } if(no_trimat ) { arma_debug_print("no_trimat"); } if(force_approx) { arma_debug_print("force_approx"); } + if(force_sym ) { arma_debug_print("force_sym"); } + + arma_conform_check( (likely_sympd), "solve(): option 'likely_sympd' not applicable to triangular matrix" ); + arma_conform_check( (force_sym ), "solve(): option 'force_sym' not applicable to triangular matrix" ); if(no_trimat || equilibrate || refine || force_approx) { @@ -518,8 +548,6 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const return glue_solve_gen_full::apply(actual_out, ((triu) ? trimatu(A_expr.get_ref()) : trimatl(A_expr.get_ref())), B_expr, (flags & mask)); } - if(likely_sympd) { arma_warn(2, "solve(): option 'likely_sympd' ignored for triangular matrix"); } - const quasi_unwrap UA(A_expr.get_ref()); const Mat& A = UA.M; @@ -527,14 +555,9 @@ glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const const uword layout = (triu) ? 
uword(0) : uword(1); - bool is_alias = true; + const bool is_alias = A_expr.get_ref().is_alias(actual_out) || B_expr.get_ref().is_alias(actual_out); - if(is_Mat::value) - { - const quasi_unwrap UB(B_expr.get_ref()); - - is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); - } + if(is_alias) { arma_debug_print("glue_solve_tri_full::apply(): aliasing detected"); } T rcond = T(0); bool status = false; diff --git a/inst/include/armadillo_bits/glue_times_meat.hpp b/inst/include/armadillo_bits/glue_times_meat.hpp index 4ffafaa8..92d69599 100644 --- a/inst/include/armadillo_bits/glue_times_meat.hpp +++ b/inst/include/armadillo_bits/glue_times_meat.hpp @@ -119,7 +119,9 @@ glue_times_redirect2_helper::apply(Mat& out, const arma_conform_assert_mul_size(A, B, "matrix multiplication"); - const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B); + const bool is_sym = (strip_inv::do_inv_spd) ? false : ( arma_config::optimise_sym && (auxlib::crippled_lapack(A) == false) && (is_sym_expr::eval(X.A) || sym_helper::is_approx_sym(A, uword(100))) ); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, B) : ( (is_sym) ? auxlib::solve_sym_fast(out, A, B) : auxlib::solve_square_fast(out, A, B) ); if(status == false) { @@ -278,7 +280,9 @@ glue_times_redirect3_helper::apply(Mat& out, const if(is_cx::yes) { arma_warn(1, "inv_sympd(): given matrix is not hermitian"); } } - const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC); + const bool is_sym = (strip_inv::do_inv_spd) ? false : ( arma_config::optimise_sym && (auxlib::crippled_lapack(A) == false) && (is_sym_expr::eval(X.A.A) || sym_helper::is_approx_sym(A, uword(100))) ); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, BC) : ( (is_sym) ? 
auxlib::solve_sym_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC) ); if(status == false) { @@ -315,7 +319,9 @@ glue_times_redirect3_helper::apply(Mat& out, const Mat solve_result; - const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C); + const bool is_sym = (strip_inv::do_inv_spd) ? false : ( arma_config::optimise_sym && (auxlib::crippled_lapack(B) == false) && (is_sym_expr::eval(X.A.B) || sym_helper::is_approx_sym(B, uword(100))) ); + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(solve_result, B, C) : ( (is_sym) ? auxlib::solve_sym_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C) ); if(status == false) { diff --git a/inst/include/armadillo_bits/mtGlue_bones.hpp b/inst/include/armadillo_bits/mtGlue_bones.hpp index 5937d89f..fc15a6f9 100644 --- a/inst/include/armadillo_bits/mtGlue_bones.hpp +++ b/inst/include/armadillo_bits/mtGlue_bones.hpp @@ -37,6 +37,9 @@ class mtGlue : public Base< out_eT, mtGlue > arma_inline mtGlue(const T1& in_A, const T2& in_B, const uword in_aux_uword); arma_inline ~mtGlue(); + template + inline bool is_alias(const Mat& X) const; + arma_aligned const T1& A; //!< first operand; must be derived from Base arma_aligned const T2& B; //!< second operand; must be derived from Base arma_aligned uword aux_uword; //!< storage of auxiliary data, uword format diff --git a/inst/include/armadillo_bits/mtGlue_meat.hpp b/inst/include/armadillo_bits/mtGlue_meat.hpp index 85cc9a21..4a64840c 100644 --- a/inst/include/armadillo_bits/mtGlue_meat.hpp +++ b/inst/include/armadillo_bits/mtGlue_meat.hpp @@ -53,4 +53,17 @@ mtGlue::~mtGlue() +template +template +inline +bool +mtGlue::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (A.is_alias(X) || B.is_alias(X)); + } + + + //! 
@} diff --git a/inst/include/armadillo_bits/mtOp_bones.hpp b/inst/include/armadillo_bits/mtOp_bones.hpp index ff0e4c38..87e74573 100644 --- a/inst/include/armadillo_bits/mtOp_bones.hpp +++ b/inst/include/armadillo_bits/mtOp_bones.hpp @@ -47,7 +47,10 @@ class mtOp : public Base< out_eT, mtOp > inline mtOp(const mtOp_dual_aux_indicator&, const T1& in_m, const in_eT in_aux_a, const out_eT in_aux_b); inline ~mtOp(); - + + template + inline bool is_alias(const Mat& X) const; + arma_aligned const T1& m; //!< the operand; must be derived from Base arma_aligned in_eT aux; //!< auxiliary data, using the element type as used by T1 diff --git a/inst/include/armadillo_bits/mtOp_meat.hpp b/inst/include/armadillo_bits/mtOp_meat.hpp index 623660c4..f53816df 100644 --- a/inst/include/armadillo_bits/mtOp_meat.hpp +++ b/inst/include/armadillo_bits/mtOp_meat.hpp @@ -101,4 +101,17 @@ mtOp::~mtOp() +template +template +inline +bool +mtOp::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + //! @} diff --git a/inst/include/armadillo_bits/op_inv_gen_meat.hpp b/inst/include/armadillo_bits/op_inv_gen_meat.hpp index 12439f67..b127eeff 100644 --- a/inst/include/armadillo_bits/op_inv_gen_meat.hpp +++ b/inst/include/armadillo_bits/op_inv_gen_meat.hpp @@ -221,7 +221,7 @@ op_inv_gen_full::apply_direct(Mat& out, const Base::eval(expr.get_ref()) || sym_helper::is_approx_sym(out, uword(100)) ) ) { arma_debug_print("op_inv_gen_full: symmetric/hermitian optimisation"); @@ -387,7 +387,7 @@ op_inv_gen_rcond::apply_direct(Mat& out, op_inv_gen_stat return auxlib::inv_tr_rcond(out, out_state.rcond, ((is_triu_expr || is_triu_mat) ? 
uword(0) : uword(1))); } - if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(out) == false) && (sym_helper::is_approx_sym(out)) ) + if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(out) == false) && ( is_sym_expr::eval(expr.get_ref()) || sym_helper::is_approx_sym(out, uword(100)) ) ) { arma_debug_print("op_inv_gen_rcond: symmetric/hermitian optimisation"); diff --git a/inst/include/armadillo_bits/op_rcond_meat.hpp b/inst/include/armadillo_bits/op_rcond_meat.hpp index c28360ad..e994e692 100644 --- a/inst/include/armadillo_bits/op_rcond_meat.hpp +++ b/inst/include/armadillo_bits/op_rcond_meat.hpp @@ -89,7 +89,7 @@ op_rcond::apply(const Base& X) return auxlib::rcond_trimat(A, layout); } - if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) && (sym_helper::is_approx_sym(A)) ) + if( (arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) && ( is_sym_expr::eval(X.get_ref()) || sym_helper::is_approx_sym(A, uword(100)) ) ) { arma_debug_print("op_rcond::apply(): symmetric/hermitian optimisation"); diff --git a/inst/include/armadillo_bits/subview_bones.hpp b/inst/include/armadillo_bits/subview_bones.hpp index 876e5afb..53b4a421 100644 --- a/inst/include/armadillo_bits/subview_bones.hpp +++ b/inst/include/armadillo_bits/subview_bones.hpp @@ -201,6 +201,9 @@ class subview : public Base< eT, subview > inline void swap_rows(const uword in_row1, const uword in_row2); inline void swap_cols(const uword in_col1, const uword in_col2); + template + inline bool is_alias(const Mat& X) const; + class const_iterator; diff --git a/inst/include/armadillo_bits/subview_elem1_bones.hpp b/inst/include/armadillo_bits/subview_elem1_bones.hpp index 2ac3cdad..37b5d725 100644 --- a/inst/include/armadillo_bits/subview_elem1_bones.hpp +++ b/inst/include/armadillo_bits/subview_elem1_bones.hpp @@ -99,6 +99,9 @@ class subview_elem1 : public Base< eT, subview_elem1 > inline static void schur_inplace(Mat& out, const subview_elem1& in); inline static 
void div_inplace(Mat& out, const subview_elem1& in); + template + inline bool is_alias(const Mat& X) const; + friend class Mat; friend class Cube; diff --git a/inst/include/armadillo_bits/subview_elem1_meat.hpp b/inst/include/armadillo_bits/subview_elem1_meat.hpp index 32a06caa..95592b23 100644 --- a/inst/include/armadillo_bits/subview_elem1_meat.hpp +++ b/inst/include/armadillo_bits/subview_elem1_meat.hpp @@ -233,9 +233,9 @@ subview_elem1::inplace_op(const Base& x) arma_conform_check( (aa_n_elem != P.get_n_elem()), "Mat::elem(): size mismatch" ); - const bool is_alias = P.is_alias(m); + const bool have_alias = P.is_alias(m); - if( (is_alias == false) && (Proxy::use_at == false) ) + if( (have_alias == false) && (Proxy::use_at == false) ) { typename Proxy::ea_type X = P.get_ea(); @@ -271,7 +271,7 @@ subview_elem1::inplace_op(const Base& x) { arma_debug_print("subview_elem1::inplace_op(): aliasing or use_at detected"); - const unwrap_check::stored_type> tmp(P.Q, is_alias); + const unwrap_check::stored_type> tmp(P.Q, have_alias); const Mat& M = tmp.M; const eT* X = M.memptr(); @@ -950,4 +950,17 @@ subview_elem1::div_inplace(Mat& out, const subview_elem1& in) +template +template +inline +bool +subview_elem1::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (m.is_alias(X) || a.get_ref().is_alias(X)); + } + + + //! 
@} diff --git a/inst/include/armadillo_bits/subview_elem2_bones.hpp b/inst/include/armadillo_bits/subview_elem2_bones.hpp index d4c4cbe6..bec561d7 100644 --- a/inst/include/armadillo_bits/subview_elem2_bones.hpp +++ b/inst/include/armadillo_bits/subview_elem2_bones.hpp @@ -103,6 +103,9 @@ class subview_elem2 : public Base< eT, subview_elem2 > inline static void schur_inplace(Mat& out, const subview_elem2& in); inline static void div_inplace(Mat& out, const subview_elem2& in); + template + inline bool is_alias(const Mat& X) const; + friend class Mat; }; diff --git a/inst/include/armadillo_bits/subview_elem2_meat.hpp b/inst/include/armadillo_bits/subview_elem2_meat.hpp index a6e9dc23..4a160642 100644 --- a/inst/include/armadillo_bits/subview_elem2_meat.hpp +++ b/inst/include/armadillo_bits/subview_elem2_meat.hpp @@ -870,4 +870,17 @@ subview_elem2::div_inplace(Mat& out, const subview_elem2& in) +template +template +inline +bool +subview_elem2::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return (m.is_alias(X) || base_ri.get_ref().is_alias(X) || base_ci.get_ref().is_alias(X)); + } + + + //! 
@} diff --git a/inst/include/armadillo_bits/subview_meat.hpp b/inst/include/armadillo_bits/subview_meat.hpp index 81cc39f9..021f7cb0 100644 --- a/inst/include/armadillo_bits/subview_meat.hpp +++ b/inst/include/armadillo_bits/subview_meat.hpp @@ -2617,6 +2617,19 @@ subview::swap_cols(const uword in_col1, const uword in_col2) +template +template +inline +bool +subview::is_alias(const Mat& X) const + { + arma_debug_sigprint(); + + return m.is_alias(X); + } + + + template inline typename subview::iterator diff --git a/inst/include/armadillo_bits/sym_helper.hpp b/inst/include/armadillo_bits/sym_helper.hpp index dcefb4f2..fa2b2b3a 100644 --- a/inst/include/armadillo_bits/sym_helper.hpp +++ b/inst/include/armadillo_bits/sym_helper.hpp @@ -403,6 +403,20 @@ is_approx_sym(const Mat& A) +template +inline +bool +is_approx_sym(const Mat& A, const uword min_n_rows) + { + arma_debug_sigprint(); + + if((A.n_rows != A.n_cols) || (A.n_rows < min_n_rows)) { return false; } + + return is_approx_sym_worker(A); + } + + + // diff --git a/inst/include/armadillo_bits/traits.hpp b/inst/include/armadillo_bits/traits.hpp index 5318ef9f..e3bf703e 100644 --- a/inst/include/armadillo_bits/traits.hpp +++ b/inst/include/armadillo_bits/traits.hpp @@ -1343,5 +1343,29 @@ struct is_sym_expr< Glue< Op, op_htrans>, Mat, glue_times > > } }; +template +struct is_sym_expr< Op > + { + static + arma_inline + bool + eval(const Op&) + { + return true; + } + }; + +template +struct is_sym_expr< Op > + { + static + arma_inline + bool + eval(const Op&) + { + return true; + } + }; + //! 
@} diff --git a/inst/include/armadillo_bits/translate_lapack.hpp b/inst/include/armadillo_bits/translate_lapack.hpp index a7cc9ee8..cdb3aa65 100644 --- a/inst/include/armadillo_bits/translate_lapack.hpp +++ b/inst/include/armadillo_bits/translate_lapack.hpp @@ -1345,7 +1345,7 @@ namespace lapack template inline void - sytrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* ipiv, eT* work, blas_int* lwork, blas_int* info) + sytrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* ipiv, eT* work, const blas_int* lwork, blas_int* info) { arma_type_check(( is_supported_blas_type::value == false )); @@ -1363,7 +1363,7 @@ namespace lapack template inline void - hetrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* ipiv, eT* work, blas_int* lwork, blas_int* info) + hetrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* ipiv, eT* work, const blas_int* lwork, blas_int* info) { arma_type_check(( is_supported_blas_type::value == false )); @@ -1378,6 +1378,42 @@ namespace lapack + template + inline + void + sytrs(const char* uplo, const blas_int* n, const blas_int* nrhs, const eT* a, const blas_int* lda, const blas_int* ipiv, eT* b, const blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float T; arma_fortran(arma_ssytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + #else + if( is_float::value) { typedef float T; arma_fortran(arma_ssytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dsytrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + #endif + } + + + + template + inline + void + hetrs(const char* uplo, const blas_int* n, 
const blas_int* nrhs, const eT* a, const blas_int* lda, const blas_int* ipiv, eT* b, const blas_int* ldb, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info, 1); } + #else + if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_chetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zhetrs)(uplo, n, nrhs, (T*)a, lda, ipiv, (T*)b, ldb, info); } + #endif + } + + + template inline void diff --git a/inst/include/armadillo_bits/unwrap.hpp b/inst/include/armadillo_bits/unwrap.hpp index 7e936317..b6be60bc 100644 --- a/inst/include/armadillo_bits/unwrap.hpp +++ b/inst/include/armadillo_bits/unwrap.hpp @@ -1075,8 +1075,8 @@ struct unwrap_check_mixed< Mat > template inline unwrap_check_mixed(const Mat& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Mat(A) : nullptr ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? new Mat(A) : nullptr ) + , M ( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? (*M_local) : A ) { arma_debug_sigprint(); } @@ -1112,8 +1112,8 @@ struct unwrap_check_mixed< Row > template inline unwrap_check_mixed(const Row& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Row(A) : nullptr ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? new Row(A) : nullptr ) + , M ( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? 
(*M_local) : A ) { arma_debug_sigprint(); } @@ -1150,8 +1150,8 @@ struct unwrap_check_mixed< Col > template inline unwrap_check_mixed(const Col& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Col(A) : nullptr ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? new Col(A) : nullptr ) + , M ( ((is_same_type::yes) && (void_ptr(&A) == void_ptr(&B))) ? (*M_local) : A ) { arma_debug_sigprint(); }