Skip to content

Commit

Permalink
fix dmat construction, add test (#298)
Browse files Browse the repository at this point in the history
  • Loading branch information
victor1234 authored Oct 27, 2021
1 parent ee32e11 commit b537e84
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
53 changes: 45 additions & 8 deletions modules/transform/wavelet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,23 +677,60 @@ blaze::CompressedMatrix<T> DaubechiesMat(size_t size, int order = 4)
c[i] = c[i] * coeff;
}

auto mat = blaze::CompressedMatrix<T>(size, size, 0);
auto mat = blaze::CompressedMatrix<T>(size, size);
mat.reserve(size * c.size());
for (size_t i = 0; i < size / 2; ++i) {
for (size_t ci = 0; ci < c.size(); ++ci) {
mat.append(i, (i * 2 + ci) % size, c[ci]);

size_t ci = mat.columns() - 2 * i;
if (ci > c.size()) {
ci = 0;
}
std::cout << "ci" << ci << std::endl;
for (size_t a = 0; a < c.size(); ++a) {
if (ci >= c.size()) {
ci = ci % c.size();
}

size_t j = i * 2 + ci;
if (j >= mat.columns()) {
j = j % mat.columns();
}

mat.append(i, j, c[ci]);
std::cout << i << " " << j << " " << ci << " " << c[ci] << std::endl;

++ci;
}
mat.finalize(i);

}

for (size_t i = size / 2; i < size; ++i) {
for (size_t i = 0; i < size / 2; ++i) {
int sign = 1;

for (size_t ci = 0; ci < c.size(); ++ci) {
mat.append(i, (i * 2 + ci) % size, c[order - 1 - ci] * sign);
sign *= -1;

size_t ci = mat.columns() - 2 * i;
if (ci > c.size()) {
ci = 0;
}
mat.finalize(i);
std::cout << "ci" << ci << std::endl;
for (size_t a = 0; a < c.size(); ++a) {
if (ci >= c.size()) {
ci = ci % c.size();
}

size_t j = i * 2 + ci;
if (j >= mat.columns()) {
j = j % mat.columns();
}

mat.append(size / 2 + i, j, c[order - 1 - ci]);
std::cout << i << " " << j << " " << ci << " " << c[ci] << std::endl;

++ci;
sign *= -1;
}
mat.finalize(size / 2 + i);
}

return mat;
Expand Down
1 change: 1 addition & 0 deletions tests/transform_test/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#target_sources(unit_tests PRIVATE hog_tests.cpp)
target_sources(unit_tests PRIVATE dwt_tests.cpp)
target_sources(unit_tests PRIVATE wavelet_tests.cpp)


#configure_file(${CMAKE_CURRENT_SOURCE_DIR}/astronaut.pgm ${CMAKE_CURRENT_BINARY_DIR}/astronaut.pgm COPYONLY)
Expand Down
19 changes: 19 additions & 0 deletions tests/transform_test/wavelet_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <catch2/catch.hpp>

#include "modules/transform/wavelet.hpp"

TEMPLATE_TEST_CASE("DaubechiesMat", "[transform][wavelet]", float)
{
const auto dmat = wavelet::DaubechiesMat<TestType>(12, 6);
std::cout << dmat << std::endl;

const auto sumLow = blaze::sum(blaze::row(dmat, 0));

for (size_t i = 0; i < dmat.rows() / 2; ++i) {
REQUIRE(Approx(sumLow) == blaze::sum(blaze::row(dmat, i)));
}
const auto sumHigh = blaze::sum(blaze::row(dmat, 0));
for (size_t i = dmat.rows() / 2; i < dmat.rows(); ++i) {
REQUIRE(Approx(sumHigh) == blaze::sum(blaze::row(dmat, i)));
}
}

0 comments on commit b537e84

Please sign in to comment.