Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pjessesco committed Dec 30, 2024
1 parent 340d9e0 commit 8f1f844
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 46 deletions.
71 changes: 39 additions & 32 deletions include/Peanut/impl/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ namespace Peanut {
static constexpr Index Col = C;

/**
* @brief Constructor without any parameters initialize to zero matrix.
* @brief Constructor without any initialization
*/
Matrix() {m_data.d1.fill(t_0);}
Matrix() {}

/**
* @brief Constructor with row-major elements.
Expand All @@ -114,13 +114,15 @@ namespace Peanut {
template <typename ...TList>
requires std::conjunction_v<std::is_same<T, TList>...> &&
(sizeof...(TList) == Row*Col)
Matrix(TList ... tlist) : m_data{{std::forward<T>(tlist)...}} {}
Matrix(TList ... tlist) : m_data{std::forward<T>(tlist)...} {}

/**
* @brief Constructor with std::array.
* @param data A std::array having `T` type and \p R * \p C size.
*/
explicit Matrix(const std::array<T, R * C> &data) : m_data{data} {}
explicit Matrix(const std::array<T, R * C> &data) {
memcpy(m_data.data(), data.data(), sizeof(T)*R*C);
}

/**
* @brief Constructor with std::vector.
Expand All @@ -138,26 +140,30 @@ namespace Peanut {
Matrix(const MatrixExpr<E> &expr) requires is_equal_type_size_v<E, Matrix>{
for(Index r=0;r< R;r++){
for(Index c=0;c< C;c++){
m_data.d2[r][c] = expr.elem(r, c);
m_data[r*C+c] = expr.elem(r, c);
}
}
}

/**
* @brief Equivalent with `Matrix::Matrix()`, but for explicit purpose.
* @brief Factory function for zero matrix
* @return Zero matrix with given \p R and \p C .
*/
static Matrix zeros() {return Matrix();}
static Matrix zeros() {
auto m = Matrix();
memset(m.m_data.data(), 0, sizeof(T)*R*C);
return m;
}

/**
* @brief Construct identity matrix. Available only for square matrix case.
* @return Identity matrix with given \p R and \p C .
*/
static Matrix identity() requires is_square_v<Matrix> {
Matrix a;
a.m_data.d1.fill(t_0);
memset(a.m_data.data(), 0, sizeof(T)*R*C);
for (Index i = 0; i < R; i++) {
a.m_data.d2[i][i] = t_1;
a.m_data[i*C+i] = t_1;
}
return a;
}
Expand Down Expand Up @@ -185,7 +191,7 @@ namespace Peanut {
int idx = 0;
constexpr size_t copy_byte = sizeof(Type) * Col;
for(const Matrix<Type, 1, Col> p : {rlist...}){
memcpy(ret.m_data.d2[idx], p.m_data.d2, copy_byte);
memcpy(&(ret.m_data[idx*C]), p.m_data.data(), copy_byte);
idx++;
}
return ret;
Expand Down Expand Up @@ -215,7 +221,7 @@ namespace Peanut {
int c = 0;
for(const Matrix<Type, Row, 1> p : {clist...}){
for(int r=0;r<Row;r++){
ret.m_data.d2[r][c] = p.m_data.d1[r];
ret.m_data[r*C+c] = p.m_data[r];
}
c++;
}
Expand All @@ -229,7 +235,7 @@ namespace Peanut {
* @return Rvalue of an element in \p r 'th Row and \p c 'th column.
*/
INLINE T elem(Index r, Index c) const{
return m_data.d2[r][c];
return m_data[r*C+c];
}

/**
Expand All @@ -241,7 +247,7 @@ namespace Peanut {
* @return Reference of an element in \p r 'th Row and \p c 'th column.
*/
INLINE T& elem(Index r, Index c) {
return m_data.d2[r][c];
return m_data[r*C+c];
}

/**
Expand All @@ -251,7 +257,7 @@ namespace Peanut {
*/
Matrix<Type, 1, Col> get_row(Index idx) const{
Matrix<Type, 1, Col> ret;
memcpy(ret.m_data.d2, m_data.d2[idx], sizeof(Type)*Col);
memcpy(ret.m_data.data(), &(m_data[idx*C]), sizeof(Type)*Col);
return ret;
}

Expand All @@ -261,7 +267,7 @@ namespace Peanut {
* @param row Row matrix which will be assigned to the r'th row of the matrix.
*/
void set_row(Index idx, const Matrix<Type, 1, Col> &row){
memcpy(m_data.d2[idx], row.m_data.d2, sizeof(Type)*Col);
memcpy(&(m_data[idx*C]), row.m_data.data(), sizeof(Type)*Col);
}

/**
Expand All @@ -272,7 +278,7 @@ namespace Peanut {
Matrix<Type, Row, 1> get_col(Index idx) const{
Matrix<Type, Row, 1> ret;
for(int i=0;i<Row;i++){
ret.m_data.d1[i] = m_data.d2[i][idx];
ret.m_data[i] = m_data[i*C+idx];
}
return ret;
}
Expand All @@ -284,7 +290,7 @@ namespace Peanut {
*/
void set_col(Index idx, const Matrix<Type, Row, 1> &col){
for(int i=0;i<Row;i++){
m_data.d2[i][idx] = col.m_data.d1[i];
m_data[i*C+idx] = col.m_data[i];
}
}

Expand All @@ -308,7 +314,7 @@ namespace Peanut {
*/
INLINE T operator[](Index i) const
requires (Row==1) || (Col==1){
return m_data.d1[i];
return m_data[i];
}

/**
Expand All @@ -319,7 +325,7 @@ namespace Peanut {
*/
INLINE T& operator[](Index i)
requires (Row==1) || (Col==1){
return m_data.d1[i];
return m_data[i];
}

/**
Expand All @@ -331,7 +337,7 @@ namespace Peanut {
T dot(const Matrix &vec) const requires (Row==1) || (Col==1){
T ret = t_0;
for(int i=0;i<Row*Col;i++){
ret += (vec.m_data.d1[i] * m_data.d1[i]);
ret += (vec.m_data[i] * m_data[i]);
}
return ret;
}
Expand All @@ -344,7 +350,7 @@ namespace Peanut {
Float length() const requires (Row==1) || (Col==1){
T ret = t_0;
for(int i=0;i<Row*Col;i++){
ret += (m_data.d1[i] * m_data.d1[i]);
ret += (m_data[i] * m_data[i]);
}
return std::sqrt(ret);
}
Expand All @@ -369,7 +375,7 @@ namespace Peanut {
* @return Max element in the vector.
*/
T max() const requires (Row==1) || (Col==1){
return *std::max_element(m_data.d1.begin(), m_data.d1.end());
return *std::max_element(m_data.begin(), m_data.end());
}

/**
Expand All @@ -378,7 +384,7 @@ namespace Peanut {
* @return Min element in the vector.
*/
T min() const requires (Row==1) || (Col==1){
return *std::min_element(m_data.d1.begin(), m_data.d1.end());
return *std::min_element(m_data.begin(), m_data.end());
}

/**
Expand All @@ -400,7 +406,7 @@ namespace Peanut {
*/
friend std::ostream &operator<<(std::ostream &os, const Matrix &matrix) {
for(int i=0;i< R * C;i++){
os << matrix.m_data.d1[i]<<" ";
os << matrix.m_data[i]<<" ";
}
return os;
}
Expand Down Expand Up @@ -448,15 +454,15 @@ namespace Peanut {
*/
constexpr T det() const requires is_square_v<Matrix>{
if constexpr(C ==1){
return m_data.d2[0][0];
return m_data[0];
}
else if constexpr (C ==2){
return m_data.d2[0][0] * m_data.d2[1][1] - m_data.d2[0][1] * m_data.d2[1][0];
return m_data[0] * m_data[C+1] - m_data[1] * m_data[C];
}
else{
T ret = static_cast<T>(0);
for_<C>([&] (auto c) {
ret += (c.value % 2 ? -1 : 1) * m_data.d2[0][c.value] * SubMat<0, c.value>(*this).eval().det();
ret += (c.value % 2 ? -1 : 1) * m_data[c.value] * SubMat<0, c.value>(*this).eval().det();
});
return ret;
}
Expand All @@ -470,7 +476,7 @@ namespace Peanut {
*/
constexpr T det2() const requires is_square_v<Matrix>{
if constexpr(C ==1){
return m_data.d2[0][0];
return m_data[0];
}
else if constexpr (C ==2){
return elem(0, 0) * elem(1, 1) - elem(0, 1) * elem(1, 0);
Expand All @@ -484,10 +490,11 @@ namespace Peanut {
}

// Matrix data
union {
std::array<T, R * C> d1;
T d2[R][C];
} m_data;
std::array<T, R*C> m_data;
// union {
// std::array<T, R * C> d1;
// T d2[R][C];
// } m_data;

private:
static constexpr T t_1 = static_cast<T>(1);
Expand Down
14 changes: 0 additions & 14 deletions test/test_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,6 @@
#include "catch_amalgamated.hpp"


TEST_CASE("Default constructor : zero matrix"){
Peanut::Matrix<int, 2, 2> zero_22_int_mat;
CHECK(zero_22_int_mat.elem(0, 0) == 0);
CHECK(zero_22_int_mat.elem(0, 1) == 0);
CHECK(zero_22_int_mat.elem(1, 0) == 0);
CHECK(zero_22_int_mat.elem(1, 1) == 0);

Peanut::Matrix<float, 2, 2> zero_22_float_mat;
CHECK(zero_22_float_mat.elem(0, 0) == Catch::Approx(0.0f));
CHECK(zero_22_float_mat.elem(0, 1) == Catch::Approx(0.0f));
CHECK(zero_22_float_mat.elem(1, 0) == Catch::Approx(0.0f));
CHECK(zero_22_float_mat.elem(1, 1) == Catch::Approx(0.0f));
}

TEST_CASE("Construct using parameter pack"){
Peanut::Matrix<int, 2, 2> intmat(1,2,3,4);
CHECK(intmat.elem(0, 0) == 1);
Expand Down

0 comments on commit 8f1f844

Please sign in to comment.