From 45c69b9a034b36fcd6a51a2f4148612019a9370b Mon Sep 17 00:00:00 2001 From: kleeman Date: Fri, 30 Dec 2022 15:08:25 -0800 Subject: [PATCH] add block diag methods --- .../albatross/src/linalg/block_diagonal.hpp | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/include/albatross/src/linalg/block_diagonal.hpp b/include/albatross/src/linalg/block_diagonal.hpp index 21c9ea10..8e65c1c4 100644 --- a/include/albatross/src/linalg/block_diagonal.hpp +++ b/include/albatross/src/linalg/block_diagonal.hpp @@ -37,10 +37,9 @@ struct BlockDiagonalLDLT { solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, ThreadPool *pool) const; - template - Eigen::Matrix<_Scalar, _Rows, _Cols> - sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, - ThreadPool *pool) const; + template + Eigen::MatrixXd sqrt_solve(const Eigen::DenseBase &rhs, + ThreadPool *pool) const; BlockDiagonal sqrt_transpose() const; @@ -51,6 +50,8 @@ struct BlockDiagonalLDLT { Eigen::Index rows() const; Eigen::Index cols() const; + + bool operator==(const BlockDiagonalLDLT &other) const; }; struct BlockDiagonal { @@ -141,20 +142,23 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, return output; } -template -inline Eigen::Matrix<_Scalar, _Rows, _Cols> -BlockDiagonalLDLT::sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs, +template +inline Eigen::MatrixXd +BlockDiagonalLDLT::sqrt_solve(const Eigen::DenseBase &rhs, ThreadPool *pool) const { ALBATROSS_ASSERT(cols() == rhs.rows()); - Eigen::Matrix<_Scalar, _Rows, _Cols> output(rows(), rhs.cols()); + Eigen::MatrixXd output(rows(), rhs.cols()); auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) { - const auto rhs_chunk = rhs.block(row, 0, blocks[i].rows(), rhs.cols()); + const auto rhs_chunk = + rhs.derived().block(row, 0, blocks[i].rows(), rhs.cols()); output.block(row, 0, blocks[i].rows(), rhs.cols()) = blocks[i].sqrt_solve(rhs_chunk); }; - apply_map(block_to_row_map(), solve_and_fill_one_block, pool); + // Intentionally leaving pool out here due to an unknown bug + // in which the thread pool version crashes in sqrt_solve. + apply_map(block_to_row_map(), solve_and_fill_one_block); return output; } @@ -182,6 +186,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const { return n; } +inline bool +BlockDiagonalLDLT::operator==(const BlockDiagonalLDLT &other) const { + return blocks == other.blocks; +} /* * Block Diagonal */