Skip to content

Commit

Permalink
refactor batched serial pbtrs implementation details and tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Feb 20, 2025
1 parent e56863a commit b0dc43e
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 142 deletions.
14 changes: 10 additions & 4 deletions batched/dense/src/KokkosBatched_Pbtrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ namespace KokkosBatched {
/// where U is an upper triangular matrix, U**H is the transpose of U, and
/// L is lower triangular matrix, L**H is the transpose of L.
///
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D
/// view
/// \tparam BViewType: Input type for a right-hand side and the solution,
/// needs to be a 1D view
/// \tparam ArgUplo: Type indicating whether A is the upper (Uplo::Upper) or lower (Uplo::Lower) triangular matrix
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Pbtrs::Blocked) or unblocked
/// (KokkosBatched::Algo::Pbtrs::Unblocked) algorithm to be used
///
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D view
/// \tparam BViewType: Input type for a right-hand side and the solution, needs to be a 1D view
///
/// \param ab [in]: ab is a ldab by n banded matrix, with ( kd + 1 ) diagonals
/// \param b [inout]: right-hand side and the solution, a rank 1 view
Expand All @@ -45,6 +47,10 @@ namespace KokkosBatched {

template <typename ArgUplo, typename ArgAlgo>
struct SerialPbtrs {
static_assert(
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>,
"KokkosBatched::pbtrs: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix");
static_assert(std::is_same_v<ArgAlgo, Algo::Pbtrs::Unblocked>, "KokkosBatched::pbtrs: Use Algo::Pbtrs::Unblocked");
template <typename ABViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ABViewType &ab, const BViewType &b);
};
Expand Down
2 changes: 0 additions & 2 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@
#include "Test_Batched_SerialPttrs_Complex.hpp"
#include "Test_Batched_SerialPbtrf.hpp"
#include "Test_Batched_SerialPbtrs.hpp"
#include "Test_Batched_SerialPbtrs_Real.hpp"
#include "Test_Batched_SerialPbtrs_Complex.hpp"
#include "Test_Batched_SerialLaswp.hpp"
#include "Test_Batched_SerialIamax.hpp"
#include "Test_Batched_SerialGetrf.hpp"
Expand Down
163 changes: 117 additions & 46 deletions batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include "KokkosBatched_Pbtrs.hpp"
#include "Test_Batched_DenseUtils.hpp"

using namespace KokkosBatched;

namespace Test {
namespace Pbtrs {

Expand All @@ -36,14 +34,14 @@ struct ParamTag {
template <typename DeviceType, typename ABViewType, typename ParamTagType, typename AlgoTagType>
struct Functor_BatchedSerialPbtrf {
using execution_space = typename DeviceType::execution_space;
ABViewType _ab;
ABViewType m_ab;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialPbtrf(const ABViewType &ab) : _ab(ab) {}
Functor_BatchedSerialPbtrf(const ABViewType &ab) : m_ab(ab) {}

KOKKOS_INLINE_FUNCTION
void operator()(const ParamTagType &, const int k) const {
auto sub_ab = Kokkos::subview(_ab, k, Kokkos::ALL(), Kokkos::ALL());
auto sub_ab = Kokkos::subview(m_ab, k, Kokkos::ALL(), Kokkos::ALL());

KokkosBatched::SerialPbtrf<typename ParamTagType::uplo, AlgoTagType>::invoke(sub_ab);
}
Expand All @@ -53,24 +51,24 @@ struct Functor_BatchedSerialPbtrf {
std::string name_region("KokkosBatched::Test::SerialPbtrs");
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, _ab.extent(0));
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, m_ab.extent(0));
Kokkos::parallel_for(name.c_str(), policy, *this);
}
};

template <typename DeviceType, typename ABViewType, typename BViewType, typename ParamTagType, typename AlgoTagType>
struct Functor_BatchedSerialPbtrs {
using execution_space = typename DeviceType::execution_space;
ABViewType _ab;
BViewType _b;
ABViewType m_ab;
BViewType m_b;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialPbtrs(const ABViewType &ab, const BViewType &b) : _ab(ab), _b(b) {}
Functor_BatchedSerialPbtrs(const ABViewType &ab, const BViewType &b) : m_ab(ab), m_b(b) {}

KOKKOS_INLINE_FUNCTION
void operator()(const ParamTagType &, const int k, int &info) const {
auto sub_ab = Kokkos::subview(_ab, k, Kokkos::ALL(), Kokkos::ALL());
auto bb = Kokkos::subview(_b, k, Kokkos::ALL());
auto sub_ab = Kokkos::subview(m_ab, k, Kokkos::ALL(), Kokkos::ALL());
auto bb = Kokkos::subview(m_b, k, Kokkos::ALL());

info += KokkosBatched::SerialPbtrs<typename ParamTagType::uplo, AlgoTagType>::invoke(sub_ab, bb);
}
Expand All @@ -82,7 +80,7 @@ struct Functor_BatchedSerialPbtrs {
std::string name = name_region + name_value_type;
int info_sum = 0;
Kokkos::Profiling::pushRegion(name.c_str());
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, _b.extent(0));
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, m_b.extent(0));
Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum);
Kokkos::Profiling::popRegion();
return info_sum;
Expand All @@ -92,70 +90,88 @@ struct Functor_BatchedSerialPbtrs {
template <typename DeviceType, typename ScalarType, typename AViewType, typename xViewType, typename yViewType>
struct Functor_BatchedSerialGemv {
using execution_space = typename DeviceType::execution_space;
AViewType _a;
xViewType _x;
yViewType _y;
ScalarType _alpha, _beta;
AViewType m_a;
xViewType m_x;
yViewType m_y;
ScalarType m_alpha, m_beta;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta,
const yViewType &y)
: _a(a), _x(x), _y(y), _alpha(alpha), _beta(beta) {}
: m_a(a), m_x(x), m_y(y), m_alpha(alpha), m_beta(beta) {}

KOKKOS_INLINE_FUNCTION
void operator()(const int k) const {
auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL());
auto xx = Kokkos::subview(_x, k, Kokkos::ALL());
auto yy = Kokkos::subview(_y, k, Kokkos::ALL());
auto aa = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL());
auto xx = Kokkos::subview(m_x, k, Kokkos::ALL());
auto yy = Kokkos::subview(m_y, k, Kokkos::ALL());

KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(_alpha, aa, xx, _beta, yy);
KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(m_alpha, aa, xx, m_beta, yy);
}

inline void run() {
using value_type = typename AViewType::non_const_value_type;
std::string name_region("KokkosBatched::Test::SerialPbtrs");
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::RangePolicy<execution_space> policy(0, _x.extent(0));
Kokkos::RangePolicy<execution_space> policy(0, m_x.extent(0));
Kokkos::parallel_for(name.c_str(), policy, *this);
}
};

template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
/// \brief Implementation details of batched pbtrs test
/// Confirm A * x = b, where
/// Confirm A * x = b, where
/// A: [[4, 1, 0],
/// [1, 4, 1],
/// [0, 1, 4]]
/// b: [1, 1, 1]
/// x: [3/14, 1/7, 3/14]
///
/// This corresponds to the following system of equations:
/// This corresponds to the following system of equations:
/// 4 x0 + x1 = 1
/// x0 + 4 x1 + x2 = 1
/// x1 + 4 x2 = 1
///
/// We confirm this with the factorized band matrix Ub or Lb.
/// For upper banded storage, Ab = Ub**H * Ub
/// Ub: [[0, 1/sqrt(4), 1/sqrt(4 - (1/sqrt(4))**2)],
/// [sqrt(4), sqrt(4 - (1/sqrt(4))**2), sqrt(4 - 1/sqrt(4 - (1/sqrt(4))**2))],]
/// For lower banded storage, Ab = Lb * Lb**H
/// Lb: [[sqrt(4), sqrt(4 - (1/sqrt(4))**2), sqrt(4 - 1/sqrt(4 - (1/sqrt(4))**2))],
/// [1/sqrt(4), 1/sqrt(4 - (1/sqrt(4))**2), 0],]
///
/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix)
/// \param k [in] Number of superdiagonals or subdiagonals of matrix A
/// \param BlkSize [in] Block size of matrix A
template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
void impl_test_batched_pbtrs_analytical(const int N) {
using ats = typename Kokkos::ArithTraits<ScalarType>;
using RealType = typename ats::mag_type;
using View2DType = Kokkos::View<ScalarType **, LayoutType, DeviceType>;
using View3DType = Kokkos::View<ScalarType ***, LayoutType, DeviceType>;

constexpr int BlkSize = 3, k = 1;
View3DType A("A", N, BlkSize, BlkSize), A_reconst("A_reconst", N, BlkSize, BlkSize);
View3DType Ab("Ab", N, k + 1, BlkSize); // Banded matrix
View2DType x0("x0", N, BlkSize), x_ref("x_ref", N, BlkSize), y0("y0", N, BlkSize); // Solutions
const int BlkSize = 3, k = 1;
View3DType Ab("Ab", N, k + 1, BlkSize); // In band storage
View2DType x0("x0", N, BlkSize), x_ref("x_ref", N, BlkSize); // Solutions

auto h_A_reconst = Kokkos::create_mirror_view(A_reconst);
auto h_x_ref = Kokkos::create_mirror_view(x_ref);
auto h_Ab = Kokkos::create_mirror_view(Ab);
auto h_x_ref = Kokkos::create_mirror_view(x_ref);

for (int ib = 0; ib < N; ib++) {
for (int i = 0; i < BlkSize; i++) {
for (int j = 0; j < BlkSize; j++) {
h_A_reconst(ib, i, j) = i == j ? 4.0 : 1.0;
}
if (std::is_same_v<typename ParamTagType::uplo, KokkosBatched::Uplo::Upper>) {
// Ub
h_Ab(ib, 1, 0) = Kokkos::sqrt(4.0);
h_Ab(ib, 0, 1) = 1.0 / h_Ab(ib, 1, 0);
h_Ab(ib, 1, 1) = Kokkos::sqrt(4.0 - h_Ab(ib, 0, 1) * h_Ab(ib, 0, 1));
h_Ab(ib, 0, 2) = 1.0 / h_Ab(ib, 1, 1);
h_Ab(ib, 1, 2) = Kokkos::sqrt(4.0 - h_Ab(ib, 0, 2) * h_Ab(ib, 0, 2));
} else {
// Lb
h_Ab(ib, 0, 0) = Kokkos::sqrt(4.0);
h_Ab(ib, 1, 0) = 1.0 / h_Ab(ib, 0, 0);
h_Ab(ib, 0, 1) = Kokkos::sqrt(4.0 - h_Ab(ib, 1, 0) * h_Ab(ib, 1, 0));
h_Ab(ib, 1, 1) = 1.0 / h_Ab(ib, 0, 1);
h_Ab(ib, 0, 2) = Kokkos::sqrt(4.0 - h_Ab(ib, 1, 1) * h_Ab(ib, 1, 1));
}

h_x_ref(ib, 0) = 3.0 / 14.0;
Expand All @@ -166,15 +182,7 @@ void impl_test_batched_pbtrs_analytical(const int N) {
Kokkos::fence();

Kokkos::deep_copy(x0, ScalarType(1.0));
Kokkos::deep_copy(A_reconst, h_A_reconst);

// Create banded triangluar matrix in normal and banded storage
using ArgUplo = typename ParamTagType::uplo;
create_banded_pds_matrix<View3DType, View3DType, ArgUplo>(A_reconst, A, k, false);
create_banded_triangular_matrix<View3DType, View3DType, ArgUplo>(A_reconst, Ab, k, true);

// Factorize with Pbtrf: A = U**H * U or A = L * L**H
Functor_BatchedSerialPbtrf<DeviceType, View3DType, ParamTagType, AlgoTagType>(Ab).run();
Kokkos::deep_copy(Ab, h_Ab);

// pbtrs (Note, Ab is a factorized matrix of A)
auto info = Functor_BatchedSerialPbtrs<DeviceType, View3DType, View2DType, ParamTagType, AlgoTagType>(Ab, x0).run();
Expand All @@ -194,13 +202,16 @@ void impl_test_batched_pbtrs_analytical(const int N) {
}
}

template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
/// \brief Implementation details of batched pbtrs test
/// Confirm A * x = b, where
/// Confirm A * x = b, where A is a real symmetric positive definitie
/// or complex Hermitian band matrix. A is storead in a band storage.
/// A must be factorized as A=U**H*U or A=L*L**H (Cholesky factorization)
/// by pbtrf.
///
/// \param N [in] Batch size of RHS (banded matrix can also be batched matrix)
/// \param k [in] Number of superdiagonals or subdiagonals of matrix A
/// \param BlkSize [in] Block size of matrix A
template <typename DeviceType, typename ScalarType, typename LayoutType, typename ParamTagType, typename AlgoTagType>
void impl_test_batched_pbtrs(const int N, const int k, const int BlkSize) {
using ats = typename Kokkos::ArithTraits<ScalarType>;
using RealType = typename ats::mag_type;
Expand Down Expand Up @@ -293,3 +304,63 @@ int test_batched_pbtrs() {

return 0;
}

#if defined(KOKKOSKERNELS_INST_FLOAT)
TEST_F(TestCategory, test_batched_pbtrs_l_float) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

test_batched_pbtrs<TestDevice, float, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_pbtrs_u_float) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

test_batched_pbtrs<TestDevice, float, param_tag_type, algo_tag_type>();
}
#endif

#if defined(KOKKOSKERNELS_INST_DOUBLE)
TEST_F(TestCategory, test_batched_pbtrs_l_double) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

test_batched_pbtrs<TestDevice, double, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_pbtrs_u_double) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

test_batched_pbtrs<TestDevice, double, param_tag_type, algo_tag_type>();
}
#endif

#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT)
TEST_F(TestCategory, test_batched_pbtrs_l_fcomplex) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

test_batched_pbtrs<TestDevice, Kokkos::complex<float>, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_pbtrs_u_fcomplex) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

test_batched_pbtrs<TestDevice, Kokkos::complex<float>, param_tag_type, algo_tag_type>();
}
#endif

#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE)
TEST_F(TestCategory, test_batched_pbtrs_l_dcomplex) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Lower>;

test_batched_pbtrs<TestDevice, Kokkos::complex<double>, param_tag_type, algo_tag_type>();
}
TEST_F(TestCategory, test_batched_pbtrs_u_dcomplex) {
using algo_tag_type = typename Algo::Pbtrs::Unblocked;
using param_tag_type = ::Test::Pbtrs::ParamTag<KokkosBatched::Uplo::Upper>;

test_batched_pbtrs<TestDevice, Kokkos::complex<double>, param_tag_type, algo_tag_type>();
}
#endif
45 changes: 0 additions & 45 deletions batched/dense/unit_test/Test_Batched_SerialPbtrs_Complex.hpp

This file was deleted.

Loading

0 comments on commit b0dc43e

Please sign in to comment.