Skip to content

Commit

Permalink
add block diag methods
Browse files Browse the repository at this point in the history
  • Loading branch information
akleeman committed Dec 31, 2022
1 parent 22514f6 commit 45c69b9
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions include/albatross/src/linalg/block_diagonal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
ThreadPool *pool) const;

template <class _Scalar, int _Rows, int _Cols>
Eigen::Matrix<_Scalar, _Rows, _Cols>
sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
ThreadPool *pool) const;
template <typename Derived>
Eigen::MatrixXd sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
ThreadPool *pool) const;

BlockDiagonal sqrt_transpose() const;

Expand All @@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
Eigen::Index rows() const;

Eigen::Index cols() const;

bool operator==(const BlockDiagonalLDLT &other) const;
};

struct BlockDiagonal {
Expand Down Expand Up @@ -141,20 +142,23 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
return output;
}

template <class _Scalar, int _Rows, int _Cols>
inline Eigen::Matrix<_Scalar, _Rows, _Cols>
BlockDiagonalLDLT::sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
template <typename Derived>
inline Eigen::MatrixXd
BlockDiagonalLDLT::sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
ThreadPool *pool) const {
ALBATROSS_ASSERT(cols() == rhs.rows());
Eigen::Matrix<_Scalar, _Rows, _Cols> output(rows(), rhs.cols());
Eigen::MatrixXd output(rows(), rhs.cols());

auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
const auto rhs_chunk = rhs.block(row, 0, blocks[i].rows(), rhs.cols());
const auto rhs_chunk =
rhs.derived().block(row, 0, blocks[i].rows(), rhs.cols());
output.block(row, 0, blocks[i].rows(), rhs.cols()) =
blocks[i].sqrt_solve(rhs_chunk);
};

apply_map(block_to_row_map(), solve_and_fill_one_block, pool);
// Intentionally leaving pool out here due to an unknown bug
// in which the thread pool version crashes in sqrt_solve.
apply_map(block_to_row_map(), solve_and_fill_one_block);
return output;
}

Expand Down Expand Up @@ -182,6 +186,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
return n;
}

inline bool
BlockDiagonalLDLT::operator==(const BlockDiagonalLDLT &other) const {
return blocks == other.blocks;
}
/*
* Block Diagonal
*/
Expand Down

0 comments on commit 45c69b9

Please sign in to comment.