Skip to content

Commit

Permalink
refactor serial pbtrf implementation details and tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Feb 19, 2025
1 parent b5ec4ab commit 365e70c
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 309 deletions.
13 changes: 8 additions & 5 deletions batched/dense/impl/KokkosBatched_Pbtrf_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

namespace Impl {
template <typename ABViewType>
KOKKOS_INLINE_FUNCTION static int checkPbtrfInput([[maybe_unused]] const ABViewType &Ab) {
static_assert(Kokkos::is_view_v<ABViewType>, "KokkosBatched::pbtrf: ABViewType is not a Kokkos::View.");
Expand All @@ -41,6 +41,7 @@ KOKKOS_INLINE_FUNCTION static int checkPbtrfInput([[maybe_unused]] const ABViewT
#endif
return 0;
}
} // namespace Impl

//// Lower ////
template <>
Expand All @@ -51,11 +52,12 @@ struct SerialPbtrf<Uplo::Lower, Algo::Pbtrf::Unblocked> {
const int n = Ab.extent(1);
if (n == 0) return 0;

auto info = checkPbtrfInput(Ab);
auto info = Impl::checkPbtrfInput(Ab);
if (info) return info;

const int kd = Ab.extent(0) - 1;
return SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(), kd);
return Impl::SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(),
kd);
}
};

Expand All @@ -68,11 +70,12 @@ struct SerialPbtrf<Uplo::Upper, Algo::Pbtrf::Unblocked> {
const int n = Ab.extent(1);
if (n == 0) return 0;

auto info = checkPbtrfInput(Ab);
auto info = Impl::checkPbtrfInput(Ab);
if (info) return info;

const int kd = Ab.extent(0) - 1;
return SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(), kd);
return Impl::SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(n, Ab.data(), Ab.stride_0(), Ab.stride_1(),
kd);
}
};

Expand Down
173 changes: 24 additions & 149 deletions batched/dense/impl/KokkosBatched_Pbtrf_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@

#include "KokkosBatched_Util.hpp"
#include "KokkosBlas1_serial_scal_impl.hpp"
#include "KokkosBatched_Syr_Serial_Internal.hpp"
#include "KokkosBatched_Lacgv_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
Expand All @@ -36,17 +39,8 @@ struct SerialPbtrfInternalLower {
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
/**/ ValueType *KOKKOS_RESTRICT AB, const int as0, const int as1,
const int kd);

template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0,
const int as1, const int kd);
};

///
/// Real matrix
///

template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(const int an,
Expand All @@ -55,7 +49,7 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::inv
const int kd) {
// Compute the Cholesky factorization A = L*L'.
for (int j = 0; j < an; ++j) {
auto a_jj = AB[0 * as0 + j * as1];
auto a_jj = Kokkos::ArithTraits<ValueType>::real(AB[0 * as0 + j * as1]);

// Check if L (j, j) is positive definite
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
Expand All @@ -75,68 +69,13 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::inv
const ValueType alpha = 1.0 / a_jj;
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[1 * as0 + j * as1]), 1);

// syr (lower) with alpha = -1.0 to diagonal elements
for (int k = 0; k < kn; ++k) {
auto x_k = AB[(k + 1) * as0 + j * as1];
if (x_k != 0) {
auto temp = -1.0 * x_k;
for (int i = k; i < kn; ++i) {
auto x_i = AB[(i + 1) * as0 + j * as1];
AB[i * as0 + (j + 1 + k - i) * as1] += x_i * temp;
}
}
}
}
}

return 0;
}

///
/// Complex matrix
///
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalLower<Algo::Pbtrf::Unblocked>::invoke(
const int an,
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0, const int as1, const int kd) {
// Compute the Cholesky factorization A = L*L**H
for (int j = 0; j < an; ++j) {
auto a_jj = AB[0 * as0 + j * as1].real();

// Check if L (j, j) is positive definite
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
if (a_jj <= 0) {
AB[0 * as0 + j * as1] = a_jj;
return j + 1;
}
#endif

a_jj = Kokkos::sqrt(a_jj);
AB[0 * as0 + j * as1] = a_jj;

// Compute elements J+1:J+KN of column J and update the
// trailing submatrix within the band.
int kn = Kokkos::min(kd, an - j - 1);
if (kn > 0) {
// scale to diagonal elements
const ValueType alpha = 1.0 / a_jj;
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[1 * as0 + j * as1]), 1);

// zher (lower) with alpha = -1.0 to diagonal elements
for (int k = 0; k < kn; ++k) {
auto x_k = AB[(k + 1) * as0 + j * as1];
if (x_k != 0) {
auto temp = -1.0 * Kokkos::conj(x_k);
AB[k * as0 + (j + 1) * as1] = AB[k * as0 + (j + 1) * as1].real() + (temp * x_k).real();
for (int i = k + 1; i < kn; ++i) {
auto x_i = AB[(i + 1) * as0 + j * as1];
AB[i * as0 + (j + 1 + k - i) * as1] += x_i * temp;
}
} else {
AB[k * as0 + (j + 1) * as1] = AB[k * as0 + (j + 1) * as1].real();
}
}
// syr or zher (lower) with alpha = -1.0 to diagonal elements
using op = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj,
KokkosBlas::Impl::OpID>;
using op_sym = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpReal,
KokkosBlas::Impl::OpID>;
SerialSyrInternalLower::invoke(op(), op_sym(), kn, -1.0, &(AB[1 * as0 + j * as1]), as0,
&(AB[0 * as0 + (j + 1) * as1]), as0, (as1 - as0));
}
}

Expand All @@ -153,16 +92,8 @@ struct SerialPbtrfInternalUpper {
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
/**/ ValueType *KOKKOS_RESTRICT AB, const int as0, const int as1,
const int kd);

template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int an,
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0,
const int as1, const int kd);
};

///
/// Real matrix
///
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(const int an,
Expand All @@ -171,7 +102,7 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::inv
const int kd) {
// Compute the Cholesky factorization A = U'*U.
for (int j = 0; j < an; ++j) {
auto a_jj = AB[kd * as0 + j * as1];
auto a_jj = Kokkos::ArithTraits<ValueType>::real(AB[kd * as0 + j * as1]);

// Check if U (j,j) is positive definite
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
Expand All @@ -191,82 +122,26 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::inv
const ValueType alpha = 1.0 / a_jj;
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[(kd - 1) * as0 + (j + 1) * as1]), kld);

// syr (upper) with alpha = -1.0 to diagonal elements
for (int k = 0; k < kn; ++k) {
auto x_k = AB[(k + kd - 1) * as0 + (j + 1 - k) * as1];
if (x_k != 0) {
auto temp = -1.0 * x_k;
for (int i = 0; i < k + 1; ++i) {
auto x_i = AB[(i + kd - 1) * as0 + (j + 1 - i) * as1];
AB[(kd + i) * as0 + (j + 1 + k - i) * as1] += x_i * temp;
}
}
}
}
}

return 0;
}

///
/// Complex matrix
///
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPbtrfInternalUpper<Algo::Pbtrf::Unblocked>::invoke(
const int an,
/**/ Kokkos::complex<ValueType> *KOKKOS_RESTRICT AB, const int as0, const int as1, const int kd) {
// Compute the Cholesky factorization A = U**H * U.
for (int j = 0; j < an; ++j) {
auto a_jj = AB[kd * as0 + j * as1].real();

// Check if U (j,j) is positive definite
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
if (a_jj <= 0) {
AB[kd * as0 + j * as1] = a_jj;
return j + 1;
}
#endif

a_jj = Kokkos::sqrt(a_jj);
AB[kd * as0 + j * as1] = a_jj;

// Compute elements J+1:J+KN of row J and update the
// trailing submatrix within the band.
int kn = Kokkos::min(kd, an - j - 1);
int kld = Kokkos::max(1, as0 - 1);
if (kn > 0) {
// scale to diagonal elements
const ValueType alpha = 1.0 / a_jj;
KokkosBlas::Impl::SerialScaleInternal::invoke(kn, alpha, &(AB[(kd - 1) * as0 + (j + 1) * as1]), kld);

// zlacgv to diagonal elements
for (int i = 0; i < kn; ++i) {
AB[(i + kd - 1) * as0 + (j + 1 - i) * as1] = Kokkos::conj(AB[(i + kd - 1) * as0 + (j + 1 - i) * as1]);
}
// zlacgv to diagonal elements (no op for real matrix)
SerialLacgvInternal::invoke(kn, &(AB[(kd - 1) * as0 + (j + 1) * as1]), (as0 - as1));

// zher (upper) with alpha = -1.0 to diagonal elements
for (int k = 0; k < kn; ++k) {
auto x_k = AB[(k + kd - 1) * as0 + (j + 1 - k) * as1];
if (x_k != 0) {
auto temp = -1.0 * Kokkos::conj(x_k);
for (int i = 0; i < k + 1; ++i) {
auto x_i = AB[(i + kd - 1) * as0 + (j + 1 - i) * as1];
AB[(kd + i) * as0 + (j + 1 + k - i) * as1] += x_i * temp;
}
}
}
// syr or zher (upper) with alpha = -1.0 to diagonal elements
using op = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj,
KokkosBlas::Impl::OpID>;
using op_sym = std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpReal,
KokkosBlas::Impl::OpID>;
SerialSyrInternalUpper::invoke(op(), op_sym(), kn, -1.0, &(AB[(kd - 1) * as0 + (j + 1) * as1]), as0,
&(AB[kd * as0 + (j + 1) * as1]), as0, (as1 - as0));

// zlacgv to diagonal elements
for (int i = 0; i < kn; ++i) {
AB[(i + kd - 1) * as0 + (j + 1 - i) * as1] = Kokkos::conj(AB[(i + kd - 1) * as0 + (j + 1 - i) * as1]);
}
// zlacgv to diagonal elements (no op for real matrix)
SerialLacgvInternal::invoke(kn, &(AB[(kd - 1) * as0 + (j + 1) * as1]), (as0 - as1));
}
}

return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PBTRF_SERIAL_INTERNAL_HPP_
11 changes: 9 additions & 2 deletions batched/dense/src/KokkosBatched_Pbtrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ namespace KokkosBatched {
/// L is lower triangular.
/// This is the unblocked version of the algorithm, calling Level 2 BLAS.
///
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D
/// view
/// \tparam ArgUplo: Type indicating whether the upper (Uplo::Upper) or lower (Uplo::Lower) triangular part of A is stored
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Pbtrf::Blocked) or unblocked
/// (KokkosBatched::Algo::Pbtrf::Unblocked) algorithm to be used
///
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D view
///
/// \param ab [inout]: ab is a ldab by n banded matrix, with ( kd + 1 ) diagonals
///
Expand All @@ -43,6 +46,10 @@ namespace KokkosBatched {

template <typename ArgUplo, typename ArgAlgo>
struct SerialPbtrf {
// Compile-time validation of the template arguments: instantiating this
// primary template with an ArgUplo other than Uplo::Upper or Uplo::Lower
// fails with the message below.
static_assert(
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>,
"KokkosBatched::pbtrf: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix");
// Algo::Pbtrf::Unblocked is the only ArgAlgo accepted here.
static_assert(std::is_same_v<ArgAlgo, Algo::Pbtrf::Unblocked>, "KokkosBatched::pbtrf: Use Algo::Pbtrf::Unblocked");
// Entry point: takes the banded matrix view `ab` and returns an int status.
// Only the declaration lives here; the definition is supplied elsewhere
// (presumably by the Uplo/Algo specializations -- confirm in the impl header).
template <typename ABViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ABViewType &ab);
};
Expand Down
2 changes: 0 additions & 2 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@
#include "Test_Batched_SerialPttrs_Real.hpp"
#include "Test_Batched_SerialPttrs_Complex.hpp"
#include "Test_Batched_SerialPbtrf.hpp"
#include "Test_Batched_SerialPbtrf_Real.hpp"
#include "Test_Batched_SerialPbtrf_Complex.hpp"
#include "Test_Batched_SerialPbtrs.hpp"
#include "Test_Batched_SerialPbtrs_Real.hpp"
#include "Test_Batched_SerialPbtrs_Complex.hpp"
Expand Down
Loading

0 comments on commit 365e70c

Please sign in to comment.