From 26f810b4a46978ae8fe880b7d9c04c512a56b433 Mon Sep 17 00:00:00 2001 From: Yuuichi Asahi Date: Fri, 7 Feb 2025 19:28:56 +0900 Subject: [PATCH] implement batched serial syr Signed-off-by: Yuuichi Asahi --- .../impl/KokkosBatched_Syr_Serial_Impl.hpp | 146 ++++++ .../KokkosBatched_Syr_Serial_Internal.hpp | 88 ++++ batched/dense/src/KokkosBatched_Syr.hpp | 49 ++ .../dense/unit_test/Test_Batched_Dense.hpp | 1 + .../unit_test/Test_Batched_SerialSyr.hpp | 434 ++++++++++++++++++ 5 files changed, 718 insertions(+) create mode 100644 batched/dense/impl/KokkosBatched_Syr_Serial_Impl.hpp create mode 100644 batched/dense/impl/KokkosBatched_Syr_Serial_Internal.hpp create mode 100644 batched/dense/src/KokkosBatched_Syr.hpp create mode 100644 batched/dense/unit_test/Test_Batched_SerialSyr.hpp diff --git a/batched/dense/impl/KokkosBatched_Syr_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Syr_Serial_Impl.hpp new file mode 100644 index 0000000000..f40efa63f7 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Syr_Serial_Impl.hpp @@ -0,0 +1,146 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ +#define KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ + +#include +#include +#include "KokkosBatched_Syr_Serial_Internal.hpp" + +namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkSyrInput([[maybe_unused]] const XViewType &x, + [[maybe_unused]] const AViewType &A) { + static_assert(Kokkos::is_view_v, "KokkosBatched::syr: XViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::syr: AViewType is not a Kokkos::View."); + static_assert(XViewType::rank == 1, "KokkosBatched::syr: XViewType must have rank 1."); + static_assert(AViewType::rank == 2, "KokkosBatched::syr: AViewType must have rank 2."); +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) /* runtime extent checks: each failure prints a message and returns 1 */ + const int lda = A.extent_int(0), n = A.extent_int(1); + const int m = x.extent_int(0); /* length of x, compared against n below */ + + if (n < 0) { + Kokkos::printf( + "KokkosBatched::syr: input parameter n must not be less than 0: n " + "= " + "%d\n", + n); + return 1; + } + + if (m != n) { + Kokkos::printf( + "KokkosBatched::syr: x must contain n elements: n " + "= " + "%d\n", + n); + return 1; + } + + if (lda < Kokkos::max(1, n)) { + Kokkos::printf( + "KokkosBatched::syr: leading dimension of A must not be smaller than " + "max(1, n): " + "lda = %d, n = %d\n", + lda, n); + return 1; + } +#endif + return 0; +} +} // namespace Impl + +// {s,d,c,z}syr interface +// L T: lower triangle of A, x**T (no conjugation) +// A := alpha * x * x**T + A +template <> +struct SerialSyr { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { + // Quick return: nothing to update when n == 0 or alpha == 0 + const int n = A.extent_int(1); + if (n == 0 || (alpha == ScalarType(0))) return 0; + + auto info = Impl::checkSyrInput(x, A); + if (info) return info; + + return Impl::SerialSyrInternalLower::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), n, alpha, x.data(), + x.stride(0), A.data(), A.stride(0), 
A.stride(1)); + } +}; + +// {s,d,c,z}syr interface +// U T: upper triangle of A, x**T (no conjugation) +// A := alpha * x * x**T + A +template <> +struct SerialSyr { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { + // Quick return: nothing to update when n == 0 or alpha == 0 + const int n = A.extent_int(1); + if (n == 0 || (alpha == ScalarType(0))) return 0; + + auto info = Impl::checkSyrInput(x, A); + if (info) return info; + + return Impl::SerialSyrInternalUpper::invoke(KokkosBlas::Impl::OpID(), KokkosBlas::Impl::OpID(), n, alpha, x.data(), + x.stride(0), A.data(), A.stride(0), A.stride(1)); + } +}; + +// {c,z}her interface +// L C: lower triangle of A, x**H (OpConj on x, OpReal on the diagonal) +// A := alpha * x * x**H + A +template <> +struct SerialSyr { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { + // Quick return: nothing to update when n == 0 or alpha == 0 + const int n = A.extent_int(1); + if (n == 0 || (alpha == ScalarType(0))) return 0; + + auto info = Impl::checkSyrInput(x, A); + if (info) return info; + + return Impl::SerialSyrInternalLower::invoke(KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpReal(), n, alpha, + x.data(), x.stride(0), A.data(), A.stride(0), A.stride(1)); + } +}; + +// {c,z}her interface +// U C: upper triangle of A, x**H (OpConj on x, OpReal on the diagonal) +// A := alpha * x * x**H + A +template <> +struct SerialSyr { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) { + // Quick return: nothing to update when n == 0 or alpha == 0 + const int n = A.extent_int(1); + if (n == 0 || (alpha == ScalarType(0))) return 0; + + auto info = Impl::checkSyrInput(x, A); + if (info) return info; + + return Impl::SerialSyrInternalUpper::invoke(KokkosBlas::Impl::OpConj(), KokkosBlas::Impl::OpReal(), n, alpha, + x.data(), x.stride(0), A.data(), A.stride(0), A.stride(1)); + } +}; + +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_ diff --git a/batched/dense/impl/KokkosBatched_Syr_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Syr_Serial_Internal.hpp new file mode 100644 
index 0000000000..21316217dd --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Syr_Serial_Internal.hpp @@ -0,0 +1,88 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ +#define KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ + +#include + +namespace KokkosBatched { +namespace Impl { + +/// +/// Serial Internal Impl +/// ==================== + +/// Lower: for each j, updates A(j, j) and A(i, j) with i > j; always returns 0 + +struct SerialSyrInternalLower { + template + KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT A, const int as0, const int as1); +}; + +template +KOKKOS_INLINE_FUNCTION int SerialSyrInternalLower::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { + for (int j = 0; j < an; j++) { + if (x[j * xs0] != ValueType(0)) { /* x(j) == 0 adds nothing to A(:, j) */ + auto temp = alpha * op(x[j * xs0]); + A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1] + x[j * xs0] * temp); + for (int i = j + 1; i < an; i++) { /* entries A(i, j) with i > j */ + A[i * as0 + j * as1] += x[i * xs0] * temp; + } + } else { /* sym_op is still applied to the diagonal entry */ + A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1]); + } + } + + return 0; +} + +/// Upper: for each j, updates A(i, j) with i < j and A(j, j); always returns 0 + +struct SerialSyrInternalUpper { + template + KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType 
*KOKKOS_RESTRICT A, const int as0, const int as1); +}; + +template +KOKKOS_INLINE_FUNCTION int SerialSyrInternalUpper::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha, + const ValueType *KOKKOS_RESTRICT x, const int xs0, + ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) { + for (int j = 0; j < an; j++) { + if (x[j * xs0] != ValueType(0)) { /* x(j) == 0 adds nothing to A(:, j) */ + auto temp = alpha * op(x[j * xs0]); + for (int i = 0; i < j; i++) { /* entries A(i, j) with i < j */ + A[i * as0 + j * as1] += x[i * xs0] * temp; + } + A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1] + x[j * xs0] * temp); + } else { /* sym_op is still applied to the diagonal entry */ + A[j * as0 + j * as1] = sym_op(A[j * as0 + j * as1]); + } + } + + return 0; +} + +} // namespace Impl +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_ diff --git a/batched/dense/src/KokkosBatched_Syr.hpp b/batched/dense/src/KokkosBatched_Syr.hpp new file mode 100644 index 0000000000..da4bf40513 --- /dev/null +++ b/batched/dense/src/KokkosBatched_Syr.hpp @@ -0,0 +1,49 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#ifndef KOKKOSBATCHED_SYR_HPP_ +#define KOKKOSBATCHED_SYR_HPP_ + +#include + +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) + +namespace KokkosBatched { + +/// \brief Serial Batched Syr: +/// Performs the symmetric rank 1 operation +/// A := alpha*x*x**T + A or A := alpha*x*x**H + A +/// where alpha is a scalar, x is an n element vector, and A is an n by n symmetric or Hermitian matrix. 
+/// +/// \tparam ScalarType: Input type for the scalar alpha +/// \tparam XViewType: Input type for the vector x, needs to be a 1D view +/// \tparam AViewType: Input/output type for the matrix A, needs to be a 2D view +/// +/// \param alpha [in]: alpha is a scalar +/// \param x [in]: x is a length n vector, a rank 1 view +/// \param A [inout]: A is an n by n matrix, a rank 2 view; only its lower or upper triangle is updated +/// +/// No nested parallel_for is used inside of the function. Returns 0 on success, non-zero if the input check fails. +/// +template +struct SerialSyr { + template + KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &a); +}; +} // namespace KokkosBatched + +#include "KokkosBatched_Syr_Serial_Impl.hpp" + +#endif // KOKKOSBATCHED_SYR_HPP_ diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index e9086f1c93..5270bcee3b 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -66,6 +66,7 @@ #include "Test_Batched_SerialGetrf.hpp" #include "Test_Batched_SerialGetrs.hpp" #include "Test_Batched_SerialGer.hpp" +#include "Test_Batched_SerialSyr.hpp" // Team Kernels #include "Test_Batched_TeamAxpy.hpp" diff --git a/batched/dense/unit_test/Test_Batched_SerialSyr.hpp b/batched/dense/unit_test/Test_Batched_SerialSyr.hpp new file mode 100644 index 0000000000..1f70dcaf67 --- /dev/null +++ b/batched/dense/unit_test/Test_Batched_SerialSyr.hpp @@ -0,0 +1,434 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) +#include +#include +#include +#include +#include +#include "Test_Batched_DenseUtils.hpp" + +namespace Test { +namespace Syr { + +template +struct ParamTag { + using uplo = U; + using trans = T; +}; + +template +struct Functor_BatchedSerialSyr { + using execution_space = typename DeviceType::execution_space; + XViewType m_x; + AViewType m_A; + ScalarType m_alpha; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialSyr(const ScalarType alpha, const XViewType &x, const AViewType &A) + : m_x(x), m_A(A), m_alpha(alpha) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int k, int &info) const { /* handles batch entry k; info accumulates the return codes */ + auto sub_x = Kokkos::subview(m_x, k, Kokkos::ALL()); + auto sub_A = Kokkos::subview(m_A, k, Kokkos::ALL(), Kokkos::ALL()); + + info += KokkosBatched::SerialSyr::invoke(m_alpha, sub_x, + sub_A); + } + + inline int run() { /* launches one reduction over the batch and returns the summed return codes */ + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialSyr"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + int info_sum = 0; + Kokkos::Profiling::pushRegion(name.c_str()); + Kokkos::RangePolicy policy(0, m_A.extent(0)); + Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum); + Kokkos::Profiling::popRegion(); + return info_sum; + } +}; + +/// \brief Implementation details of batched syr analytical test +/// to confirm A := alpha*x*x**T + A is computed correctly +/// Test data (identical for every batch entry): +/// alpha = 1.5 +/// 4x4 matrix (upper) +/// U: [[1, -3, -2, 0], +/// [0, 1, -3, -2], +/// [0, 0, 1, -3], +/// [0, 0, 0, 1]] +/// x: [1, 2, 3, 4] +/// Ref: [[ 2.5, 0., 2.5, 6., ], +/// [ 0., 7., 6., 10., ], +/// [ 0., 0., 14.5, 15., ], +/// [ 0., 0., 0., 25., ]] +/// +/// 4x4 matrix (lower) +/// L: [[1, 0, 0, 0], +/// [-1, 1, 0, 0], +/// [2, -1, 1, 0], +/// [0, 2, -1, 1]] +/// x: [1, 2, 3, 4] +/// Ref: [[ 2.5, 0., 
0., 0., ], +/// [ 2., 7., 0., 0., ], +/// [ 6.5, 8., 14.5, 0., ], +/// [ 6., 14., 17., 25., ]] +/// +/// \param Nb [in] Batch size of matrices +template +void impl_test_batched_syr_analytical(const std::size_t Nb) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + using View2DType = Kokkos::View; + using StridedView2DType = Kokkos::View; + using View3DType = Kokkos::View; + using ArgUplo = typename ParamTagType::uplo; + + const std::size_t BlkSize = 4; + View3DType A("A", Nb, BlkSize, BlkSize), A_s("A_s", Nb, BlkSize, BlkSize), A_ref("A_ref", Nb, BlkSize, BlkSize); + + View2DType x("x", Nb, BlkSize); + + const std::size_t incx = 2; + // Testing incx argument with strided views + Kokkos::LayoutStride layout{Nb, incx, BlkSize, Nb * incx}; + StridedView2DType x_s("x_s", layout); + + // Fill host mirrors of A, A_ref and x; A and x are deep-copied to the device afterwards + auto h_A = Kokkos::create_mirror_view(A); + auto h_A_ref = Kokkos::create_mirror_view(A_ref); + auto h_x = Kokkos::create_mirror_view(x); + + for (std::size_t ib = 0; ib < Nb; ib++) { + h_A(ib, 0, 0) = 1; + h_A(ib, 0, 1) = -3; + h_A(ib, 0, 2) = -2; + h_A(ib, 0, 3) = 0; + h_A(ib, 1, 0) = -1; + h_A(ib, 1, 1) = 1; + h_A(ib, 1, 2) = -3; + h_A(ib, 1, 3) = -2; + h_A(ib, 2, 0) = 2; + h_A(ib, 2, 1) = -1; + h_A(ib, 2, 2) = 1; + h_A(ib, 2, 3) = -3; + h_A(ib, 3, 0) = 0; + h_A(ib, 3, 1) = 2; + h_A(ib, 3, 2) = -1; + h_A(ib, 3, 3) = 1; + + if (std::is_same_v) { + h_A_ref(ib, 0, 0) = 2.5; + h_A_ref(ib, 0, 1) = 0; + h_A_ref(ib, 0, 2) = 2.5; + h_A_ref(ib, 0, 3) = 6; + h_A_ref(ib, 1, 0) = 0; + h_A_ref(ib, 1, 1) = 7; + h_A_ref(ib, 1, 2) = 6; + h_A_ref(ib, 1, 3) = 10; + h_A_ref(ib, 2, 0) = 0; + h_A_ref(ib, 2, 1) = 0; + h_A_ref(ib, 2, 2) = 14.5; + h_A_ref(ib, 2, 3) = 15.; + h_A_ref(ib, 3, 0) = 0; + h_A_ref(ib, 3, 1) = 0; + h_A_ref(ib, 3, 2) = 0; + h_A_ref(ib, 3, 3) = 25.; + } else { + h_A_ref(ib, 0, 0) = 2.5; + h_A_ref(ib, 0, 1) = 0; + h_A_ref(ib, 0, 2) = 0; + h_A_ref(ib, 0, 3) = 0; + h_A_ref(ib, 1, 0) = 2; + 
h_A_ref(ib, 1, 1) = 7; + h_A_ref(ib, 1, 2) = 0; + h_A_ref(ib, 1, 3) = 0; + h_A_ref(ib, 2, 0) = 6.5; + h_A_ref(ib, 2, 1) = 8; + h_A_ref(ib, 2, 2) = 14.5; + h_A_ref(ib, 2, 3) = 0; + h_A_ref(ib, 3, 0) = 6; + h_A_ref(ib, 3, 1) = 14; + h_A_ref(ib, 3, 2) = 17; + h_A_ref(ib, 3, 3) = 25; + } + + for (std::size_t j = 0; j < BlkSize; j++) { + h_x(ib, j) = static_cast(j + 1); + } + } + + Kokkos::deep_copy(A, h_A); + Kokkos::deep_copy(x, h_x); + + // Upper or lower triangular part of A into A_s + create_triangular_matrix(A, A_s); + + Kokkos::deep_copy(A, A_s); + + // Deep copy to strided views + Kokkos::deep_copy(x_s, x); + + const ScalarType alpha = 1.5; + + auto info = Functor_BatchedSerialSyr(alpha, x, A).run(); + + Kokkos::fence(); + EXPECT_EQ(info, 0); + + // With strided views + info = Functor_BatchedSerialSyr(alpha, x_s, A_s) + .run(); + + Kokkos::fence(); + EXPECT_EQ(info, 0); + + RealType eps = 1.0e1 * ats::epsilon(); + Kokkos::deep_copy(h_A, A); + auto h_A_s = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A_s); + + // Check if A:= alpha * x * x**T + A + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + for (std::size_t j = 0; j < BlkSize; j++) { + EXPECT_NEAR_KK(h_A(ib, i, j), h_A_ref(ib, i, j), eps); + EXPECT_NEAR_KK(h_A_s(ib, i, j), h_A_ref(ib, i, j), eps); + } + } + } +} + +/// \brief Implementation details of batched syr test +/// +/// \param Nb [in] Batch size of matrices +/// \param BlkSize [in] Block size of matrix A +template +void impl_test_batched_syr(const std::size_t Nb, const std::size_t BlkSize) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + using View2DType = Kokkos::View; + using StridedView2DType = Kokkos::View; + using View3DType = Kokkos::View; + using ArgUplo = typename ParamTagType::uplo; + + View3DType A("A", Nb, BlkSize, BlkSize), A0("A0", Nb, BlkSize, BlkSize), A_s("A_s", Nb, BlkSize, BlkSize), + A0_s("A0_s", Nb, BlkSize, BlkSize), A_ref("A_ref", Nb, 
BlkSize, BlkSize), A0_ref("A0_ref", Nb, BlkSize, BlkSize); + + View2DType x("x", Nb, BlkSize); + + const std::size_t incx = 2; + // Testing incx argument with strided views + Kokkos::LayoutStride layout{Nb, incx, BlkSize, Nb * incx}; + StridedView2DType x_s("x_s", layout); + + // Create a random matrix A and x + using execution_space = typename DeviceType::execution_space; + Kokkos::Random_XorShift64_Pool rand_pool(13718); + ScalarType randStart, randEnd; + + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(A, rand_pool, randStart, randEnd); + Kokkos::fill_random(x, rand_pool, randStart, randEnd); + + // Keep only the upper or lower triangular part of A (stored in A_ref, then copied back to A) + create_triangular_matrix(A, A_ref); + + Kokkos::deep_copy(A, A_ref); + + // Deep copy to the views used by the strided-x runs + Kokkos::deep_copy(A_s, A); + Kokkos::deep_copy(x_s, x); + + // When A0 is zero + const ScalarType alpha = 1.5; + auto info0 = + Functor_BatchedSerialSyr(alpha, x, A0).run(); + + // When A is a random matrix + auto info1 = + Functor_BatchedSerialSyr(alpha, x, A).run(); + + Kokkos::fence(); + EXPECT_EQ(info0, 0); + EXPECT_EQ(info1, 0); + + // With strided Views + info0 = + Functor_BatchedSerialSyr(alpha, x_s, A0_s) + .run(); + + // When A is a random matrix + info1 = Functor_BatchedSerialSyr(alpha, x_s, A_s) + .run(); /* NOTE(review): info0/info1 of the strided runs are never checked with EXPECT_EQ */ + + // Make a reference at host + auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); + auto h_A_ref = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A_ref); + auto h_A0_ref = Kokkos::create_mirror_view(Kokkos::HostSpace(), A0_ref); + + // Note: ConjTranspose corresponds to {c,z}her for Hermitian matrix + const bool is_conj = std::is_same_v; + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t j = 0; j < BlkSize; j++) { + if (h_x(ib, j) != 0) { + auto temp = is_conj ? 
alpha * Kokkos::ArithTraits::conj(h_x(ib, j)) : alpha * h_x(ib, j); + + if (std::is_same_v) { + for (std::size_t i = 0; i < j + 1; i++) { + h_A_ref(ib, i, j) = h_A_ref(ib, i, j) + h_x(ib, i) * temp; + h_A0_ref(ib, i, j) = h_x(ib, i) * temp; + } + } else { + for (std::size_t i = j; i < BlkSize; i++) { + h_A_ref(ib, i, j) = h_A_ref(ib, i, j) + h_x(ib, i) * temp; + h_A0_ref(ib, i, j) = h_x(ib, i) * temp; + } + } + h_A_ref(ib, j, j) = is_conj ? Kokkos::ArithTraits::real(h_A_ref(ib, j, j)) : h_A_ref(ib, j, j); + h_A0_ref(ib, j, j) = is_conj ? Kokkos::ArithTraits::real(h_A0_ref(ib, j, j)) : h_A0_ref(ib, j, j); + } + } + } + + RealType eps = 1.0e1 * ats::epsilon(); + + auto h_A = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A); + auto h_A0 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A0); + auto h_A_s = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A_s); + auto h_A0_s = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A0_s); + + // Check if A:= alpha * x * x**T + A or A:= alpha * x * x**H + A + for (std::size_t ib = 0; ib < Nb; ib++) { + for (std::size_t i = 0; i < BlkSize; i++) { + for (std::size_t j = 0; j < BlkSize; j++) { + EXPECT_NEAR_KK(h_A(ib, i, j), h_A_ref(ib, i, j), eps); + EXPECT_NEAR_KK(h_A0(ib, i, j), h_A0_ref(ib, i, j), eps); + EXPECT_NEAR_KK(h_A_s(ib, i, j), h_A_ref(ib, i, j), eps); + EXPECT_NEAR_KK(h_A0_s(ib, i, j), h_A0_ref(ib, i, j), eps); + } + } + } +} + +} // namespace Syr +} // namespace Test + +template +int test_batched_syr() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + using LayoutType = Kokkos::LayoutLeft; + Test::Syr::impl_test_batched_syr_analytical(1); + Test::Syr::impl_test_batched_syr_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Syr::impl_test_batched_syr(1, i); + Test::Syr::impl_test_batched_syr(2, i); + } + } +#endif // KOKKOSKERNELS_INST_LAYOUTLEFT +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + using LayoutType = Kokkos::LayoutRight; + Test::Syr::impl_test_batched_syr_analytical(1); + 
Test::Syr::impl_test_batched_syr_analytical(2); + for (int i = 0; i < 10; i++) { /* block sizes 0..9, batch sizes 1 and 2 */ + Test::Syr::impl_test_batched_syr(1, i); + Test::Syr::impl_test_batched_syr(2, i); + } + } +#endif /* KOKKOSKERNELS_INST_LAYOUTRIGHT */ + + return 0; +} + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_F(TestCategory, test_batched_syr_l_t_float) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +TEST_F(TestCategory, test_batched_syr_l_c_float) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +TEST_F(TestCategory, test_batched_syr_u_t_float) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +TEST_F(TestCategory, test_batched_syr_u_c_float) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +#endif /* KOKKOSKERNELS_INST_FLOAT */ + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_F(TestCategory, test_batched_syr_l_t_double) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +TEST_F(TestCategory, test_batched_syr_l_c_double) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +TEST_F(TestCategory, test_batched_syr_u_t_double) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +TEST_F(TestCategory, test_batched_syr_u_c_double) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr(); +} +#endif /* KOKKOSKERNELS_INST_DOUBLE */ + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_F(TestCategory, test_batched_syr_l_t_fcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +TEST_F(TestCategory, test_batched_syr_l_c_fcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +TEST_F(TestCategory, test_batched_syr_u_t_fcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +TEST_F(TestCategory, test_batched_syr_u_c_fcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +#endif /* KOKKOSKERNELS_INST_COMPLEX_FLOAT */ + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) 
+TEST_F(TestCategory, test_batched_syr_l_t_dcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +TEST_F(TestCategory, test_batched_syr_l_c_dcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +TEST_F(TestCategory, test_batched_syr_u_t_dcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +TEST_F(TestCategory, test_batched_syr_u_c_dcomplex) { + using param_tag_type = ::Test::Syr::ParamTag; + test_batched_syr, param_tag_type>(); +} +#endif /* KOKKOSKERNELS_INST_COMPLEX_DOUBLE */