Implement batched serial syr #2497
Why is there no test or implementation for the non-transpose case? Are we expecting that users will not want the non-transpose case?
template <>
struct SerialSyr<Uplo::Lower, Trans::NoTranspose>
...
Actually, the ... Maybe I can introduce
// Dummy interface, we need transpose for {s,d,c,z}syr or {c,z}her
template <>
struct SerialSyr<Uplo::Upper, Trans::NoTranspose> {
template <typename ScalarType, typename XViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType /*alpha*/, const XViewType /*&x*/, const AViewType /*&A*/) {
static_assert(false, "KokkosBatched::syr: Use Trans::Transpose for {s,d,c,z}syr or Trans::ConjTranspose for {c,z}her");
}
};
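One caveat on the dummy interface above: depending on the compiler and standard version, a plain `static_assert(false, ...)` inside a function template can fire as soon as the template is defined, even if `invoke` is never called. A common workaround is to make the condition depend on a template parameter. The following is only a minimal sketch of that idea, not the code in this PR; `always_false`, `SerialSyrSketch`, and the tag structs are hypothetical stand-ins for the KokkosBatched types, and the view arguments are omitted.

```cpp
#include <type_traits>

// Hypothetical stand-ins for Uplo::Upper, Trans::Transpose and Trans::NoTranspose.
struct Upper {};
struct Transpose {};
struct NoTranspose {};

// Always false, but dependent on T, so the assertion below is only
// evaluated when invoke() is actually instantiated.
template <typename T>
struct always_false : std::false_type {};

template <typename ArgUplo, typename ArgTrans>
struct SerialSyrSketch;

// Supported case: placeholder body that just reports success.
template <>
struct SerialSyrSketch<Upper, Transpose> {
  template <typename ScalarType>
  static int invoke(const ScalarType /*alpha*/) {
    return 0;
  }
};

// Unsupported case: compiles as long as nobody calls invoke().
template <>
struct SerialSyrSketch<Upper, NoTranspose> {
  template <typename ScalarType>
  static int invoke(const ScalarType /*alpha*/) {
    static_assert(always_false<ScalarType>::value,
                  "syr: use Transpose for {s,d,c,z}syr or ConjTranspose for {c,z}her");
    return 0;
  }
};
```

With this variant, the error message appears only at the call site that selects the unsupported combination.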
I think some small changes are needed to be more specific about the supported options.
Signed-off-by: Yuuichi Asahi <[email protected]>
d7ac236 to b48a9b2
@lucbv
This PR implements the batched serial syr function.
The following files are added:
KokkosBatched_Syr_Serial_Impl.hpp: Internal interfaces
KokkosBatched_Syr_Serial_Internal.hpp: Implementation details
KokkosBatched_Syr.hpp: APIs
Test_Batched_SerialSyr.hpp: Unit tests

Detailed description
It performs the rank-1 operation
A := alpha*x*x**T + A ({s,d,c,z}syr) or
A := alpha*x*x**H + A ({c,z}her).
Here, x and A have the following shapes.
x: (batch_count, n)
On entry, it contains the n elements of x.
A: (batch_count, lda, n)
On entry, the leading n by n part of the array A must contain the matrix of coefficients.
On exit, A is overwritten by the updated matrix.
It should be noted that only the Upper or Lower triangular part is modified.
Imaginary parts of the diagonal elements are suppressed for {c, z}her.
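To make the update rule concrete, here is a small host-side reference for a single matrix (one batch entry). It is a sketch under stated assumptions, not the implementation added in this PR: `syr_ref`, `conj_if`, and `drop_imag` are hypothetical names, A is assumed to be a dense row-major n by n array with lda equal to n, and the `conjugate` flag selects the her variant.

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Conjugation helper: identity for real types, std::conj for complex types.
template <typename T>
T conj_if(const T& v) { return v; }
template <typename T>
std::complex<T> conj_if(const std::complex<T>& v) { return std::conj(v); }

// Zero the imaginary part: identity for real types.
template <typename T>
T drop_imag(const T& v) { return v; }
template <typename T>
std::complex<T> drop_imag(const std::complex<T>& v) { return {v.real(), T(0)}; }

// Reference rank-1 update on one row-major n x n matrix (assumed lda == n).
// upper == true touches only entries with i <= j, otherwise only i >= j.
// conjugate == true computes alpha*x*x**H (her), else alpha*x*x**T (syr).
template <typename T>
void syr_ref(bool upper, bool conjugate, T alpha,
             const std::vector<T>& x, std::vector<T>& A) {
  const std::size_t n = x.size();
  for (std::size_t j = 0; j < n; ++j) {
    const T xj = conjugate ? conj_if(x[j]) : x[j];
    const std::size_t i_begin = upper ? 0 : j;
    const std::size_t i_end = upper ? j + 1 : n;
    for (std::size_t i = i_begin; i < i_end; ++i) {
      A[i * n + j] += alpha * x[i] * xj;
    }
    // For her, the diagonal stays real.
    if (conjugate) A[j * n + j] = drop_imag(A[j * n + j]);
  }
}
```

For example, calling `syr_ref(true, false, 2.0, x, A)` with `x = {1, 2}` and a zero 2 by 2 `A` fills only the upper triangle and leaves the strictly lower entry untouched.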
Parallelization is done over the batch direction. This is efficient only when A is given in LayoutLeft for GPUs and LayoutRight for CPUs.

Tests
Set up x and A, copying A into A_ref. The reference A_ref is computed on the host as A_ref := alpha*x*x**T + A_ref or A_ref := alpha*x*x**H + A_ref. Finally, we confirm that A computed by serial syr and A_ref are the same. The A == 0 case is tested as well.
Set x and A to analytically known values to confirm that A is updated as expected. Both upper and lower cases are tested.