Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MueLu: Add ComputeNodesInAggregate for Aggregates_kokkos #10902

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ namespace MueLu {
for(LO i=0; i<numNodes; i++) {
LO aggregate = vertex2AggId[i];
if(aggregate !=INVALID) {
aggNodes[aggCurr[aggregate]] = i;
aggCurr[aggregate]++;
aggNodes[aggCurr[aggregate]] = i;
aggCurr[aggregate]++;
}
else {
unaggregated[currNumUnaggregated] = i;
currNumUnaggregated++;
unaggregated[currNumUnaggregated] = i;
currNumUnaggregated++;
}
}
unaggregated.resize(currNumUnaggregated);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ namespace MueLu {
using node_type = Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType>;
using device_type = DeviceType;
// Range policy over execution_space whose iteration index has type local_ordinal_type.
using range_type = Kokkos::RangePolicy<local_ordinal_type, execution_space>;
// Rank-1 view of local ordinals on device_type; this is the argument type of ComputeNodesInAggregate.
using LO_view = Kokkos::View<local_ordinal_type*, device_type>;

// Rank-1 view of LocalOrdinal on device_type; ComputeNodesInAggregate stores the result of
// ComputeAggregateSizes in the const_type of this view.
using aggregates_sizes_type = Kokkos::View<LocalOrdinal*, device_type>;

Expand Down Expand Up @@ -259,6 +260,12 @@ namespace MueLu {

local_graph_type GetGraph() const;

/*! @brief Generates a compressed (CSR-style) list of the nodes in each aggregate.

The entries aggNodes[aggPtr[i]] up to aggNodes[aggPtr[i+1]-1] contain the nodes in aggregate i.
unaggregated contains the list of nodes which are, for whatever reason, not aggregated (e.g. Dirichlet).
All three arguments are overwritten with newly allocated views.

@param[out] aggPtr        Start offsets into aggNodes; holds one entry per aggregate plus a final end offset.
@param[out] aggNodes      Local node indices grouped by aggregate; sized to the local number of nodes.
@param[out] unaggregated  Local indices of the nodes that belong to no aggregate.
*/
void ComputeNodesInAggregate(LO_view & aggPtr, LO_view & aggNodes, LO_view & unaggregated) const;

//! @name Overridden from Teuchos::Describable
//@{

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,58 @@ namespace MueLu {

return graph_;
}

// Builds, on device, a compressed listing of the local nodes belonging to each aggregate.
//
// Steps performed:
//   1. A parallel_scan over the aggregate sizes writes the exclusive prefix sum into aggPtr
//      (numAggs+1 entries, the last one being the total number of aggregated nodes) and into
//      a scratch cursor array aggCurr. The extra last entry of aggCurr is reset to 0 and
//      serves as the insertion counter for unaggregated nodes.
//   2. A parallel_reduce counts the nodes whose aggregate id equals the invalid ordinal,
//      and unaggregated is allocated with exactly that many entries.
//   3. A parallel_for places every node index either in the slot range of its aggregate in
//      aggNodes or in unaggregated, claiming a slot with an atomic fetch-and-add.
//
// Outputs (all three arguments are reassigned to newly constructed views):
//   aggPtr       - start offset of each aggregate in aggNodes, plus a final end offset.
//   aggNodes     - node indices grouped by aggregate. It is allocated with numNodes entries,
//                  so entries at and beyond aggPtr(numAggs) are not written by this function
//                  when some nodes are unaggregated.
//   unaggregated - indices of the nodes that belong to no aggregate.
//
// NOTE(review): slots are claimed with atomics inside a parallel_for, so the order of nodes
// within one aggregate (and within unaggregated) presumably depends on execution order and
// should not be relied upon - confirm before depending on a particular ordering.
template <class LocalOrdinal, class GlobalOrdinal, class DeviceType>
void
Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::ComputeNodesInAggregate(LO_view & aggPtr, LO_view & aggNodes, LO_view & unaggregated) const {
LO numAggs = GetNumAggregates();
// Number of local entries of the vertex-to-aggregate map, i.e. all local nodes (aggregated or not).
LO numNodes = vertex2AggId_->getLocalLength();
auto vertex2AggId = vertex2AggId_->getDeviceLocalView(Xpetra::Access::ReadOnly);
// NOTE(review): the boolean argument presumably forces the sizes to be recomputed - confirm against ComputeAggregateSizes.
typename aggregates_sizes_type::const_type aggSizes = ComputeAggregateSizes(true);
// Marker value stored in vertex2AggId for nodes that belong to no aggregate.
LO INVALID = Teuchos::OrdinalTraits<LO>::invalid();

aggPtr = LO_view("aggPtr",numAggs+1);
aggNodes = LO_view("aggNodes",numNodes);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is numNodes the number of all aggregated nodes, or the number of all nodes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The number of all nodes. This is the same way it's handled in MueLu_Aggregates from what I recall, meaning there's probably a small amount of padding at the end in the case where there are unaggregated nodes.

LO_view aggCurr("agg curr",numAggs+1);

// Construct the "rowptr" and the counter
Kokkos::parallel_scan("MueLu:Aggregates:ComputeNodesInAggregate:scan", range_type(0,numAggs+1),
KOKKOS_LAMBDA(const LO aggIdx, LO& aggOffset, bool final_pass) {
LO count = 0;
if(aggIdx < numAggs)
count = aggSizes(aggIdx);
if(final_pass) {
aggPtr(aggIdx) = aggOffset;
aggCurr(aggIdx) = aggOffset;
if(aggIdx==numAggs)
aggCurr(numAggs) = 0; // use this for counting unaggregated nodes
}
aggOffset += count;
});

// Preallocate unaggregated to the correct size
LO numUnaggregated = 0;
Kokkos::parallel_reduce("MueLu:Aggregates:ComputeNodesInAggregate:unaggregatedSize", range_type(0,numNodes),
KOKKOS_LAMBDA(const LO nodeIdx, LO & count) {
if(vertex2AggId(nodeIdx,0)==INVALID)
count++;
}, numUnaggregated);
unaggregated = LO_view("unaggregated",numUnaggregated);

// Stick the nodes in each aggregate's spot
Kokkos::parallel_for("MueLu:Aggregates:ComputeNodesInAggregate:for", range_type(0,numNodes),
KOKKOS_LAMBDA(const LO nodeIdx) {
LO aggIdx = vertex2AggId(nodeIdx,0);
if(aggIdx != INVALID) {
// atomic postincrement aggCurr(aggIdx) each time
aggNodes(Kokkos::atomic_fetch_add(&aggCurr(aggIdx),1)) = nodeIdx;
} else {
// same, but using last entry of aggCurr for unaggregated nodes
unaggregated(Kokkos::atomic_fetch_add(&aggCurr(numAggs),1)) = nodeIdx;
}
});

}

template <class LocalOrdinal, class GlobalOrdinal, class DeviceType>
std::string Aggregates_kokkos<LocalOrdinal, GlobalOrdinal, Kokkos::Compat::KokkosDeviceWrapperNode<DeviceType> >::description() const {
Expand Down
14 changes: 14 additions & 0 deletions packages/muelu/test/unit_tests_kokkos/Aggregates_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,20 @@ namespace MueLuTests {
}, numBadAggregates);
TEST_EQUALITY(numBadAggregates, 0);

// Check ComputeNodesInAggregate
typename Aggregates_kokkos::LO_view aggPtr, aggNodes, unaggregated;
aggregates->ComputeNodesInAggregate(aggPtr, aggNodes, unaggregated);
TEST_EQUALITY(aggPtr.extent_int(0), numAggs+1);
// TEST_EQUALITY(unaggregated.extent_int(0), 0); // 1 unaggregated node in the MPI_4 case

// test aggPtr(i)+aggSizes(i)=aggPtr(i+1)
typename Aggregates_kokkos::LO_view::HostMirror aggPtr_h = Kokkos::create_mirror_view(aggPtr);
typename Aggregates_kokkos::aggregates_sizes_type::HostMirror aggSizes_h = Kokkos::create_mirror_view(aggSizes);
Kokkos::deep_copy(aggPtr_h, aggPtr);
Kokkos::deep_copy(aggSizes_h, aggSizes);
for(LO i=0; i<aggSizes_h.extent_int(0); ++i)
TEST_EQUALITY(aggPtr_h(i)+aggSizes_h(i), aggPtr_h(i+1));

} //UncoupledPhase3

TEUCHOS_UNIT_TEST_TEMPLATE_4_DECL(Aggregates_kokkos, AllowDroppingToCreateAdditionalDirichletRows, Scalar, LocalOrdinal, GlobalOrdinal, Node)
Expand Down