diff --git a/include/albatross/src/linalg/block_diagonal.hpp b/include/albatross/src/linalg/block_diagonal.hpp index 21c9ea10..a03dbea1 100644 --- a/include/albatross/src/linalg/block_diagonal.hpp +++ b/include/albatross/src/linalg/block_diagonal.hpp @@ -37,10 +37,9 @@ struct BlockDiagonalLDLT { solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, ThreadPool *pool) const; - template - Eigen::Matrix<_Scalar, _Rows, _Cols> - sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, - ThreadPool *pool) const; + template + Eigen::MatrixXd sqrt_solve(const Eigen::DenseBase &rhs, + ThreadPool *pool) const; BlockDiagonal sqrt_transpose() const; @@ -51,6 +50,8 @@ struct BlockDiagonalLDLT { Eigen::Index rows() const; Eigen::Index cols() const; + + bool operator==(const BlockDiagonalLDLT &other) const; }; struct BlockDiagonal { @@ -141,15 +142,16 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, return output; } -template -inline Eigen::Matrix<_Scalar, _Rows, _Cols> -BlockDiagonalLDLT::sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, +template +inline Eigen::MatrixXd +BlockDiagonalLDLT::sqrt_solve(const Eigen::DenseBase &rhs, ThreadPool *pool) const { ALBATROSS_ASSERT(cols() == rhs.rows()); - Eigen::Matrix<_Scalar, _Rows, _Cols> output(rows(), rhs.cols()); + Eigen::MatrixXd output(rows(), rhs.cols()); auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) { - const auto rhs_chunk = rhs.block(row, 0, blocks[i].rows(), rhs.cols()); + const auto rhs_chunk = + rhs.derived().block(row, 0, blocks[i].rows(), rhs.cols()); output.block(row, 0, blocks[i].rows(), rhs.cols()) = blocks[i].sqrt_solve(rhs_chunk); }; @@ -182,6 +184,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const { return n; } +inline bool +BlockDiagonalLDLT::operator==(const BlockDiagonalLDLT &other) const { + return blocks == other.blocks; +} /* * Block Diagonal */