Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor serial tbsv implementation details and tests #2478

Merged
merged 5 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions batched/dense/impl/KokkosBatched_Pbtrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

Expand Down Expand Up @@ -50,8 +51,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>::inv
SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);

// Solve L**T * X = B (L**H * X = B for complex value types), overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

return 0;
}
Expand All @@ -76,8 +78,9 @@ KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>::inv
/**/ ValueType *KOKKOS_RESTRICT x,
const int xs0, const int kd) {
// Solve U**T * X = B (U**H * X = B for complex value types), overwriting B with X.
constexpr bool do_conj = Kokkos::ArithTraits<ValueType>::is_complex;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(false, do_conj, an, A, as0, as1, x, xs0, kd);
using op =
std::conditional_t<Kokkos::ArithTraits<ValueType>::is_complex, KokkosBlas::Impl::OpConj, KokkosBlas::Impl::OpID>;
SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(op(), false, an, A, as0, as1, x, xs0, kd);

// Solve U*X = B, overwriting B with X.
SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(false, an, A, as0, as1, x, xs0, kd);
Expand Down
45 changes: 26 additions & 19 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@

/// \author Yuuichi Asahi ([email protected])

#include "KokkosBlas_util.hpp"
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

namespace KokkosBatched {

namespace Impl {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewType &A,
[[maybe_unused]] const XViewType &x, [[maybe_unused]] const int k) {
static_assert(Kokkos::is_view<AViewType>::value, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::tbsv: AViewType is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::tbsv: XViewType is not a Kokkos::View.");
static_assert(AViewType::rank == 2, "KokkosBatched::tbsv: AViewType must have rank 2.");
static_assert(XViewType::rank == 1, "KokkosBatched::tbsv: XViewType must have rank 1.");

Expand Down Expand Up @@ -63,15 +64,17 @@ KOKKOS_INLINE_FUNCTION static int checkTbsvInput([[maybe_unused]] const AViewTyp
return 0;
}

} // namespace Impl

//// Lower non-transpose ////
template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

return SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -81,11 +84,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -94,11 +98,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Lower, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

return SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -107,10 +112,10 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Upper, Trans::NoTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

return SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
return Impl::SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
}
};
Expand All @@ -120,11 +125,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Upper, Trans::Transpose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, false, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpID(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand All @@ -133,11 +139,12 @@ template <typename ArgDiag>
struct SerialTbsv<Uplo::Upper, Trans::ConjTranspose, ArgDiag, Algo::Tbsv::Unblocked> {
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x, const int k) {
auto info = checkTbsvInput(A, x, k);
auto info = Impl::checkTbsvInput(A, x, k);
if (info) return info;

return SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
ArgDiag::use_unit_diag, true, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(), k);
return Impl::SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
KokkosBlas::Impl::OpConj(), ArgDiag::use_unit_diag, A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(),
x.stride_0(), k);
}
};

Expand Down
58 changes: 18 additions & 40 deletions batched/dense/impl/KokkosBatched_Tbsv_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
#include "KokkosBatched_Util.hpp"

namespace KokkosBatched {

namespace Impl {
///
/// Serial Internal Impl
/// ====================

///
/// Lower, Non-Transpose
/// Lower
///

template <typename AlgoType>
Expand Down Expand Up @@ -70,49 +70,37 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLower<Algo::Tbsv::Unblocked>::invok

template <typename AlgoType>
struct SerialTbsvInternalLowerTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = an - 1; j >= 0; --j) {
auto temp = x[j * xs0];

if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[0 + j * as1]);
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= A[(i - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[0 + j * as1];
for (int i = Kokkos::min(an - 1, j + k); i > j; --i) {
temp -= op(A[(i - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / op(A[0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

///
/// Upper, Non-Transpose
/// Upper
///

template <typename AlgoType>
Expand Down Expand Up @@ -154,46 +142,36 @@ KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>::invok

template <typename AlgoType>
struct SerialTbsvInternalUpperTranspose {
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const bool use_unit_diag, const bool do_conj, const int an,
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(Op op, const bool use_unit_diag, const int an,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k);
};

template <>
template <typename ValueType>
template <typename Op, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>::invoke(
const bool use_unit_diag, const bool do_conj, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
const int as1,
Op op, const bool use_unit_diag, const int an, const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
/**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int k) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int j = 0; j < an; j++) {
auto temp = x[j * xs0];
if (do_conj) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= Kokkos::ArithTraits<ValueType>::conj(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / Kokkos::ArithTraits<ValueType>::conj(A[k * as0 + j * as1]);
} else {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= A[(i + k - j) * as0 + j * as1] * x[i * xs0];
}
if (!use_unit_diag) temp = temp / A[k * as0 + j * as1];
for (int i = Kokkos::max(0, j - k); i < j; ++i) {
temp -= op(A[(i + k - j) * as0 + j * as1]) * x[i * xs0];
}
if (!use_unit_diag) temp = temp / op(A[k * as0 + j * as1]);
x[j * xs0] = temp;
}

return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_TBSV_SERIAL_INTERNAL_HPP_
19 changes: 19 additions & 0 deletions batched/dense/src/KokkosBatched_Tbsv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ namespace KokkosBatched {
/// non-unit, upper or lower triangular band matrix, with ( k + 1 )
/// diagonals.
///
/// \tparam ArgUplo: Type indicating whether A is the upper (Uplo::Upper) or lower (Uplo::Lower) triangular matrix
/// \tparam ArgTrans: Type indicating the equations to be solved as follows
/// - Trans::NoTranspose: A * X = B
/// - Trans::Transpose: A**T * X = B
/// - Trans::ConjTranspose: A**H * X = B
/// \tparam ArgDiag: Type indicating whether A is the unit (Diag::Unit) or non-unit (Diag::NonUnit) triangular matrix
/// \tparam ArgAlgo: Type indicating the blocked (KokkosBatched::Algo::Tbsv::Blocked) or unblocked
/// (KokkosBatched::Algo::Tbsv::Unblocked) algorithm to be used; only Algo::Tbsv::Unblocked is currently accepted
///
/// \tparam AViewType: Input type for the matrix, needs to be a 2D view
/// \tparam XViewType: Input type for the right-hand side and the solution,
/// needs to be a 1D view
Expand All @@ -43,6 +52,16 @@ namespace KokkosBatched {

template <typename ArgUplo, typename ArgTrans, typename ArgDiag, typename ArgAlgo>
struct SerialTbsv {
static_assert(
std::is_same_v<ArgUplo, Uplo::Upper> || std::is_same_v<ArgUplo, Uplo::Lower>,
"KokkosBatched::tbsv: Use Uplo::Upper for upper triangular matrix or Uplo::Lower for lower triangular matrix");
static_assert(std::is_same_v<ArgTrans, Trans::NoTranspose> || std::is_same_v<ArgTrans, Trans::Transpose> ||
std::is_same_v<ArgTrans, Trans::ConjTranspose>,
"KokkosBatched::tbsv: Use Trans::NoTranspose, Trans::Transpose or Trans::ConjTranspose");
static_assert(
std::is_same_v<ArgDiag, Diag::Unit> || std::is_same_v<ArgDiag, Diag::NonUnit>,
"KokkosBatched::tbsv: Use Diag::Unit for unit triangular matrix or Diag::NonUnit for non-unit triangular matrix");
static_assert(std::is_same_v<ArgAlgo, Algo::Tbsv::Unblocked>, "KokkosBatched::tbsv: Use Algo::Tbsv::Unblocked");
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &X, const int k);
};
Expand Down
Loading
Loading