Skip to content

Commit

Permalink
Merge pull request kokkos#7330 from ldh4/crtrott/fix-minmaxloc
Browse files Browse the repository at this point in the history
Set an initial value index during join of MincLoc, MaxLoc or MinMaxLoc
  • Loading branch information
crtrott authored Sep 23, 2024
2 parents e9ab977 + 35f9a83 commit 5f115f4
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 4 deletions.
20 changes: 18 additions & 2 deletions core/src/Kokkos_Parallel_Reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,12 @@ struct MinLoc {
// Required
KOKKOS_INLINE_FUNCTION
void join(value_type& dest, const value_type& src) const {
if (src.val < dest.val) dest = src;
if (src.val < dest.val)
dest = src;
else if (src.val == dest.val &&
dest.loc == reduction_identity<index_type>::min()) {
dest.loc = src.loc;
}
}

KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -493,7 +498,12 @@ struct MaxLoc {
// Required
KOKKOS_INLINE_FUNCTION
void join(value_type& dest, const value_type& src) const {
if (src.val > dest.val) dest = src;
if (src.val > dest.val)
dest = src;
else if (src.val == dest.val &&
dest.loc == reduction_identity<index_type>::min()) {
dest.loc = src.loc;
}
}

KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -620,10 +630,16 @@ struct MinMaxLoc {
if (src.min_val < dest.min_val) {
dest.min_val = src.min_val;
dest.min_loc = src.min_loc;
} else if (dest.min_val == src.min_val &&
dest.min_loc == reduction_identity<index_type>::min()) {
dest.min_loc = src.min_loc;
}
if (src.max_val > dest.max_val) {
dest.max_val = src.max_val;
dest.max_loc = src.max_loc;
} else if (dest.max_val == src.max_val &&
dest.max_loc == reduction_identity<index_type>::min()) {
dest.max_loc = src.max_loc;
}
}

Expand Down
20 changes: 18 additions & 2 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_Reducer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,12 @@ struct OpenMPTargetReducerWrapper<MinLoc<Scalar, Index, Space>> {
// Required
KOKKOS_INLINE_FUNCTION
static void join(value_type& dest, const value_type& src) {
if (src.val < dest.val) dest = src;
if (src.val < dest.val)
dest = src;
else if (src.val == dest.val &&
dest.loc == reduction_identity<index_type>::min()) {
dest.loc = src.loc;
}
}

KOKKOS_INLINE_FUNCTION
Expand All @@ -215,7 +220,12 @@ struct OpenMPTargetReducerWrapper<MaxLoc<Scalar, Index, Space>> {

KOKKOS_INLINE_FUNCTION
static void join(value_type& dest, const value_type& src) {
if (src.val > dest.val) dest = src;
if (src.val > dest.val)
dest = src;
else if (src.val == dest.val &&
dest.loc == reduction_identity<index_type>::min()) {
dest.loc = src.loc;
}
}

KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -268,10 +278,16 @@ struct OpenMPTargetReducerWrapper<MinMaxLoc<Scalar, Index, Space>> {
if (src.min_val < dest.min_val) {
dest.min_val = src.min_val;
dest.min_loc = src.min_loc;
} else if (dest.min_val == src.min_val &&
dest.min_loc == reduction_identity<index_type>::min()) {
dest.min_loc = src.min_loc;
}
if (src.max_val > dest.max_val) {
dest.max_val = src.max_val;
dest.max_loc = src.max_loc;
} else if (dest.max_val == src.max_val &&
dest.max_loc == reduction_identity<index_type>::min()) {
dest.max_loc = src.max_loc;
}
}

Expand Down
Loading

0 comments on commit 5f115f4

Please sign in to comment.