Skip to content

Commit

Permalink
add block diag methods
Browse files Browse the repository at this point in the history
  • Loading branch information
akleeman committed Jan 11, 2023
1 parent 884903e commit c33d436
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions include/albatross/src/linalg/block_diagonal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ struct BlockDiagonalLDLT {
solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
ThreadPool *pool) const;

template <class _Scalar, int _Rows, int _Cols>
Eigen::Matrix<_Scalar, _Rows, _Cols>
sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
ThreadPool *pool) const;
template <typename Derived>
Eigen::MatrixXd sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
ThreadPool *pool) const;

BlockDiagonal sqrt_transpose() const;

Expand All @@ -51,6 +50,8 @@ struct BlockDiagonalLDLT {
Eigen::Index rows() const;

Eigen::Index cols() const;

bool operator==(const BlockDiagonalLDLT &other) const;
};

struct BlockDiagonal {
Expand Down Expand Up @@ -141,15 +142,16 @@ BlockDiagonalLDLT::solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
return output;
}

template <class _Scalar, int _Rows, int _Cols>
inline Eigen::Matrix<_Scalar, _Rows, _Cols>
BlockDiagonalLDLT::sqrt_solve(const Eigen::Matrix<_Scalar, _Rows, _Cols> &rhs,
template <typename Derived>
inline Eigen::MatrixXd
BlockDiagonalLDLT::sqrt_solve(const Eigen::DenseBase<Derived> &rhs,
ThreadPool *pool) const {
ALBATROSS_ASSERT(cols() == rhs.rows());
Eigen::Matrix<_Scalar, _Rows, _Cols> output(rows(), rhs.cols());
Eigen::MatrixXd output(rows(), rhs.cols());

auto solve_and_fill_one_block = [&](const size_t i, const Eigen::Index row) {
const auto rhs_chunk = rhs.block(row, 0, blocks[i].rows(), rhs.cols());
const auto rhs_chunk =
rhs.derived().block(row, 0, blocks[i].rows(), rhs.cols());
output.block(row, 0, blocks[i].rows(), rhs.cols()) =
blocks[i].sqrt_solve(rhs_chunk);
};
Expand Down Expand Up @@ -182,6 +184,10 @@ inline Eigen::Index BlockDiagonalLDLT::cols() const {
return n;
}

inline bool
BlockDiagonalLDLT::operator==(const BlockDiagonalLDLT &other) const {
return blocks == other.blocks;
}
/*
* Block Diagonal
*/
Expand Down

0 comments on commit c33d436

Please sign in to comment.