Skip to content

Commit

Permalink
implement batched serial syr
Browse files Browse the repository at this point in the history
Signed-off-by: Yuuichi Asahi <[email protected]>
  • Loading branch information
Yuuichi Asahi committed Feb 7, 2025
1 parent f469d39 commit 26f810b
Show file tree
Hide file tree
Showing 5 changed files with 718 additions and 0 deletions.
146 changes: 146 additions & 0 deletions batched/dense/impl/KokkosBatched_Syr_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_

#include <KokkosBlas_util.hpp>
#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Syr_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {
/// \brief Validates the arguments handed to the serial syr/her kernels.
///
/// Compile-time checks (always active): both arguments are Kokkos views,
/// x has rank 1 and A has rank 2.
/// Run-time checks (only when KOKKOSKERNELS_DEBUG_LEVEL > 0), with n = A.extent(1):
///   - n must not be negative,
///   - x must hold exactly n entries,
///   - A.extent(0), used here as the leading dimension, must be at least max(1, n).
///
/// \tparam XViewType: Input type for the vector x, needs to be a 1D view
/// \tparam AViewType: Input type for the matrix A, needs to be a 2D view
///
/// \param x [in]: vector view to validate
/// \param A [in]: matrix view to validate
///
/// \return 0 if every check passes, 1 (after printing a diagnostic) otherwise
template <typename XViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int checkSyrInput([[maybe_unused]] const XViewType &x,
                                                [[maybe_unused]] const AViewType &A) {
  static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::syr: XViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::syr: AViewType is not a Kokkos::View.");
  static_assert(XViewType::rank == 1, "KokkosBatched::syr: XViewType must have rank 1.");
  static_assert(AViewType::rank == 2, "KokkosBatched::syr: AViewType must have rank 2.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  const int lda = A.extent_int(0), n = A.extent_int(1);
  const int m = x.extent_int(0);

  if (n < 0) {
    Kokkos::printf(
        "KokkosBatched::syr: input parameter n must not be less than 0: n "
        "= "
        "%d\n",
        n);
    return 1;
  }

  // Compare through m (previously declared but never read) and report the
  // actual length of x so a mismatch can be diagnosed from the message alone.
  if (m != n) {
    Kokkos::printf(
        "KokkosBatched::syr: x must contain n elements: "
        "x.extent(0) = %d, n = %d\n",
        m, n);
    return 1;
  }

  if (lda < Kokkos::max(1, n)) {
    Kokkos::printf(
        "KokkosBatched::syr: leading dimension of A must not be smaller than "
        "max(1, n): "
        "lda = %d, n = %d\n",
        lda, n);
    return 1;
  }
#endif
  return 0;
}
} // namespace Impl

// {s,d,c,z}syr interface, Uplo::Lower / Trans::Transpose
// A := alpha * x * x**T + A, carried out by the lower-triangle kernel
template <>
struct SerialSyr<Uplo::Lower, Trans::Transpose> {
  template <typename ScalarType, typename XViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) {
    // Nothing to update for an empty matrix or a zero scaling factor.
    const int order = A.extent_int(1);
    if (order == 0) return 0;
    if (alpha == ScalarType(0)) return 0;

    // Forward any non-zero status reported by the argument checks.
    if (const int status = Impl::checkSyrInput(x, A); status != 0) return status;

    // The identity functor is used both for x[j] and for the diagonal entries.
    using Identity = KokkosBlas::Impl::OpID;
    return Impl::SerialSyrInternalLower::invoke(Identity(), Identity(), order, alpha, x.data(), x.stride(0), A.data(),
                                                A.stride(0), A.stride(1));
  }
};

// {s,d,c,z}syr interface, Uplo::Upper / Trans::Transpose
// A := alpha * x * x**T + A, carried out by the upper-triangle kernel
template <>
struct SerialSyr<Uplo::Upper, Trans::Transpose> {
  template <typename ScalarType, typename XViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) {
    // Nothing to update for an empty matrix or a zero scaling factor.
    const int order = A.extent_int(1);
    if (order == 0) return 0;
    if (alpha == ScalarType(0)) return 0;

    // Forward any non-zero status reported by the argument checks.
    if (const int status = Impl::checkSyrInput(x, A); status != 0) return status;

    // The identity functor is used both for x[j] and for the diagonal entries.
    using Identity = KokkosBlas::Impl::OpID;
    return Impl::SerialSyrInternalUpper::invoke(Identity(), Identity(), order, alpha, x.data(), x.stride(0), A.data(),
                                                A.stride(0), A.stride(1));
  }
};

// {c,z}her interface, Uplo::Lower / Trans::ConjTranspose
// A := alpha * x * x**H + A, carried out by the lower-triangle kernel
template <>
struct SerialSyr<Uplo::Lower, Trans::ConjTranspose> {
  template <typename ScalarType, typename XViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) {
    // Nothing to update for an empty matrix or a zero scaling factor.
    const int order = A.extent_int(1);
    if (order == 0) return 0;
    if (alpha == ScalarType(0)) return 0;

    // Forward any non-zero status reported by the argument checks.
    if (const int status = Impl::checkSyrInput(x, A); status != 0) return status;

    // OpConj is handed to the kernel for x[j], OpReal for the diagonal entries.
    using Conj = KokkosBlas::Impl::OpConj;
    using Real = KokkosBlas::Impl::OpReal;
    return Impl::SerialSyrInternalLower::invoke(Conj(), Real(), order, alpha, x.data(), x.stride(0), A.data(),
                                                A.stride(0), A.stride(1));
  }
};

// {c,z}her interface, Uplo::Upper / Trans::ConjTranspose
// A := alpha * x * x**H + A, carried out by the upper-triangle kernel
template <>
struct SerialSyr<Uplo::Upper, Trans::ConjTranspose> {
  template <typename ScalarType, typename XViewType, typename AViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &A) {
    // Nothing to update for an empty matrix or a zero scaling factor.
    const int order = A.extent_int(1);
    if (order == 0) return 0;
    if (alpha == ScalarType(0)) return 0;

    // Forward any non-zero status reported by the argument checks.
    if (const int status = Impl::checkSyrInput(x, A); status != 0) return status;

    // OpConj is handed to the kernel for x[j], OpReal for the diagonal entries.
    using Conj = KokkosBlas::Impl::OpConj;
    using Real = KokkosBlas::Impl::OpReal;
    return Impl::SerialSyrInternalUpper::invoke(Conj(), Real(), order, alpha, x.data(), x.stride(0), A.data(),
                                                A.stride(0), A.stride(1));
  }
};

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_SYR_SERIAL_IMPL_HPP_
88 changes: 88 additions & 0 deletions batched/dense/impl/KokkosBatched_Syr_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================

/// Lower

/// Rank-1 update restricted to the lower triangle (diagonal included) of A.
///
/// For every column j in [0, an):
///   - if x[j] != 0: temp = alpha * op(x[j]); the diagonal entry becomes
///     sym_op(A(j,j) + x[j] * temp) and A(i,j) += x[i] * temp for i > j;
///   - otherwise only sym_op is applied to the diagonal entry.
///
/// x is addressed with stride xs0, A with strides as0 (first index) and as1
/// (second index). Always returns 0.
struct SerialSyrInternalLower {
  template <typename Op, typename SymOp, typename ScalarType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha,
                                           const ValueType *KOKKOS_RESTRICT x, const int xs0,
                                           ValueType *KOKKOS_RESTRICT A, const int as0, const int as1);
};

template <typename Op, typename SymOp, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialSyrInternalLower::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha,
                                                          const ValueType *KOKKOS_RESTRICT x, const int xs0,
                                                          ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
  for (int col = 0; col < an; ++col) {
    const int diag = col * as0 + col * as1;

    // A zero entry of x contributes nothing to this column; only the
    // diagonal is passed through sym_op.
    if (x[col * xs0] == ValueType(0)) {
      A[diag] = sym_op(A[diag]);
      continue;
    }

    const auto scaled = alpha * op(x[col * xs0]);
    A[diag]           = sym_op(A[diag] + x[col * xs0] * scaled);

    // Strictly-lower part of the current column.
    for (int row = col + 1; row < an; ++row) {
      A[row * as0 + col * as1] += x[row * xs0] * scaled;
    }
  }

  return 0;
}

/// Upper

/// Rank-1 update restricted to the upper triangle (diagonal included) of A.
///
/// For every column j in [0, an):
///   - if x[j] != 0: temp = alpha * op(x[j]); A(i,j) += x[i] * temp for i < j
///     and the diagonal entry becomes sym_op(A(j,j) + x[j] * temp);
///   - otherwise only sym_op is applied to the diagonal entry.
///
/// x is addressed with stride xs0, A with strides as0 (first index) and as1
/// (second index). Always returns 0.
struct SerialSyrInternalUpper {
  template <typename Op, typename SymOp, typename ScalarType, typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha,
                                           const ValueType *KOKKOS_RESTRICT x, const int xs0,
                                           ValueType *KOKKOS_RESTRICT A, const int as0, const int as1);
};

template <typename Op, typename SymOp, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialSyrInternalUpper::invoke(Op op, SymOp sym_op, const int an, const ScalarType alpha,
                                                          const ValueType *KOKKOS_RESTRICT x, const int xs0,
                                                          ValueType *KOKKOS_RESTRICT A, const int as0, const int as1) {
  for (int col = 0; col < an; ++col) {
    const int diag = col * as0 + col * as1;

    // A zero entry of x contributes nothing to this column; only the
    // diagonal is passed through sym_op.
    if (x[col * xs0] == ValueType(0)) {
      A[diag] = sym_op(A[diag]);
      continue;
    }

    const auto scaled = alpha * op(x[col * xs0]);

    // Strictly-upper part of the current column, then the diagonal.
    for (int row = 0; row < col; ++row) {
      A[row * as0 + col * as1] += x[row * xs0] * scaled;
    }
    A[diag] = sym_op(A[diag] + x[col * xs0] * scaled);
  }

  return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_SYR_SERIAL_INTERNAL_HPP_
49 changes: 49 additions & 0 deletions batched/dense/src/KokkosBatched_Syr.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_SYR_HPP_
#define KOKKOSBATCHED_SYR_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi ([email protected])

namespace KokkosBatched {

/// \brief Serial Batched Syr:
/// Performs the symmetric rank 1 operation
/// A := alpha*x*x**T + A or A := alpha*x*x**H + A
/// where alpha is a scalar, x is an n element vector, and A is a n by n symmetric or Hermitian matrix.
///
/// \tparam ArgUplo: Selects the triangle of A that is updated; specializations are provided for
///                  Uplo::Lower and Uplo::Upper
/// \tparam ArgTrans: Selects the operation; Trans::Transpose performs A := alpha*x*x**T + A,
///                   Trans::ConjTranspose performs A := alpha*x*x**H + A
/// \tparam ScalarType: Input type for the scalar alpha
/// \tparam XViewType: Input type for the vector x, needs to be a 1D view
/// \tparam AViewType: Input/output type for the matrix A, needs to be a 2D view
///
/// \param alpha [in]: alpha is a scalar
/// \param x [in]: x is a length n vector, a rank 1 view
/// \param a [inout]: the n by n matrix A, a rank 2 view; only the triangle selected by ArgUplo is written
///
/// \return 0 on success (including the quick return taken when n == 0 or alpha == 0);
///         1 if the run-time argument checks, active only when KOKKOSKERNELS_DEBUG_LEVEL > 0, fail
///
/// No nested parallel_for is used inside of the function.
///
template <typename ArgUplo, typename ArgTrans>
struct SerialSyr {
template <typename ScalarType, typename XViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const XViewType &x, const AViewType &a);
};
} // namespace KokkosBatched

#include "KokkosBatched_Syr_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_SYR_HPP_
1 change: 1 addition & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
#include "Test_Batched_SerialGetrf.hpp"
#include "Test_Batched_SerialGetrs.hpp"
#include "Test_Batched_SerialGer.hpp"
#include "Test_Batched_SerialSyr.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
Loading

0 comments on commit 26f810b

Please sign in to comment.