-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Yuuichi Asahi <[email protected]>
- Loading branch information
Yuuichi Asahi
committed
Feb 7, 2025
1 parent
f469d39
commit 26f810b
Showing
5 changed files
with
718 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
|
||
#ifndef KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ | ||
#define KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ | ||
|
||
#include <KokkosBlas_util.hpp> | ||
#include <KokkosBatched_Util.hpp> | ||
#include "KokkosBatched_Syr_Serial_Internal.hpp" | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
template <typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int checkSyrInput([[maybe_unused]] const XViewType &x, | ||
[[maybe_unused]] const AViewType &A) { | ||
static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::syr: XViewType is not a Kokkos::View."); | ||
static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::syr: AViewType is not a Kokkos::View."); | ||
static_assert(XViewType::rank == 1, "KokkosBatched::syr: XViewType must have rank 1."); | ||
static_assert(AViewType::rank == 2, "KokkosBatched::syr: AViewType must have rank 2."); | ||
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) | ||
const int lda = A.extent_int(0), n = A.extent_int(1); | ||
const int m = x.extent_int(0); | ||
|
||
if (n < 0) { | ||
Kokkos::printf( | ||
"KokkosBatched::syr: input parameter n must not be less than 0: n " | ||
"= " | ||
"%d\n", | ||
n); | ||
return 1; | ||
} | ||
|
||
if (x.extent_int(0) != n) { | ||
Kokkos::printf( | ||
"KokkosBatched::syr: x must contain n elements: n " | ||
"= " | ||
"%d\n", | ||
n); | ||
return 1; | ||
} | ||
|
||
if (lda < Kokkos::max(1, n)) { | ||
Kokkos::printf( | ||
"KokkosBatched::syr: leading dimension of A must not be smaller than " | ||
"max(1, n): " | ||
"lda = %d, n = %d\n", | ||
lda, n); | ||
return 1; | ||
} | ||
#endif | ||
return 0; | ||
} | ||
} // namespace Impl | ||
|
||
// {s,d,c,z}syr interface | ||
// L T | ||
// A: alpha * x * x**T + A | ||
template <> | ||
struct SerialSyr<Uplo::Lower, Trans::Transpose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalLower::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), n, alpha, x.data(), | ||
x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// {s,d,c,z}syr interface | ||
// U T | ||
// A: alpha * x * x**T + A | ||
template <> | ||
struct SerialSyr<Uplo::Upper, Trans::Transpose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalUpper::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), n, alpha, x.data(), | ||
x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// {c,z}her interface | ||
// L C | ||
// A: alpha * x * x**H + A | ||
template <> | ||
struct SerialSyr<Uplo::Lower, Trans::ConjTranspose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalLower::invoke(KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpReal(), n, alpha, | ||
x.data(), x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
// {c,z}her interface | ||
// U C | ||
// A: alpha * x * x**H + A | ||
template <> | ||
struct SerialSyr<Uplo::Upper, Trans::ConjTranspose> { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { | ||
// Quick return if possible | ||
const int n = A.extent_int(1); | ||
if (n == 0 || (alpha == ScalarType(0))) return 0; | ||
|
||
auto info = Impl::checkSyrInput(x, A); | ||
if (info) return info; | ||
|
||
return Impl::SerialSyrInternalUpper::invoke(KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpReal(), n, alpha, | ||
x.data(), x.stride(0), A.data(), A.stride(0), A.stride(1)); | ||
} | ||
}; | ||
|
||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
|
||
#ifndef KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ | ||
#define KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
namespace KokkosBatched { | ||
namespace Impl { | ||
|
||
/// | ||
/// Serial Internal Impl | ||
/// ==================== | ||
|
||
/// Lower | ||
|
||
struct SerialSyrInternalLower { | ||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1); | ||
}; | ||
|
||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialSyrInternalLower::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { | ||
for (int j = 0; j < an; j++) { | ||
if (x[j * xs0] != ValueType(0)) { | ||
auto temp = alpha * op(x[j * xs0]); | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1] + x[j * xs0] * temp); | ||
for (int i = j + 1; i < an; i++) { | ||
A[i * as0 + j * as1] += x[i * xs0] * temp; | ||
} | ||
} else { | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1]); | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
/// Upper | ||
|
||
struct SerialSyrInternalUpper { | ||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1); | ||
}; | ||
|
||
template <typename Op, typename SymOp, typename ScalarType, typename ValueType> | ||
KOKKOS_INLINE_FUNCTION int SerialSyrInternalUpper::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, | ||
const ValueType *KOKKOS_RESTRICT x, const int xs0, | ||
ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { | ||
for (int j = 0; j < an; j++) { | ||
if (x[j * xs0] != ValueType(0)) { | ||
auto temp = alpha * op(x[j * xs0]); | ||
for (int i = 0; i < j; i++) { | ||
A[i * as0 + j * as1] += x[i * xs0] * temp; | ||
} | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1] + x[j * xs0] * temp); | ||
} else { | ||
A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1]); | ||
} | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
} // namespace Impl | ||
} // namespace KokkosBatched | ||
|
||
#endif // KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
//@HEADER | ||
// ************************************************************************ | ||
// | ||
// Kokkos v. 4.0 | ||
// Copyright (2022) National Technology & Engineering | ||
// Solutions of Sandia, LLC (NTESS). | ||
// | ||
// Under the terms of Contract DE-NA0003525 with NTESS, | ||
// the U.S. Government retains certain rights in this software. | ||
// | ||
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://kokkos.org/LICENSE for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//@HEADER | ||
#ifndef KOKKOSBATCHED_SYR_HPP_ | ||
#define KOKKOSBATCHED_SYR_HPP_ | ||
|
||
#include <KokkosBatched_Util.hpp> | ||
|
||
/// \author Yuuichi Asahi ([email protected]) | ||
|
||
namespace KokkosBatched { | ||
|
||
/// \brief Serial Batched Syr: | ||
/// Performs the symmetric rank 1 operation | ||
/// A := alpha*x*x**T + A or A := alpha*x*x**H + A | ||
/// where alpha is a scalar, x is an n element vector, and A is a n by n symmetric or Hermitian matrix. | ||
/// | ||
/// \tparam ScalarType: Input type for the scalar alpha | ||
/// \tparam XViewType: Input type for the vector x, needs to be a 1D view | ||
/// \tparam AViewType: Input/output type for the matrix A, needs to be a 2D view | ||
/// | ||
/// \param alpha [in]: alpha is a scalar | ||
/// \param x [in]: x is a length n vector, a rank 1 view | ||
/// \param A [inout]: A is a n by n matrix, a rank 2 view | ||
/// | ||
/// No nested parallel_for is used inside of the function. | ||
/// | ||
template <typename ArgUplo, typename ArgTrans> | ||
struct SerialSyr { | ||
template <typename ScalarType, typename XViewType, typename AViewType> | ||
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &a); | ||
}; | ||
} // namespace KokkosBatched | ||
|
||
#include "KokkosBatched_Syr_Serial_Impl.hpp" | ||
|
||
#endif // KOKKOSBATCHED_SYR_HPP_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.