Skip to content

Commit

Permalink
BUG: device use bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian-Diaz committed Dec 13, 2024
1 parent 373c943 commit 64c8073
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
14 changes: 5 additions & 9 deletions examples/ann_distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,10 @@ int main(int argc, char* argv[])
TpetraDFArray<real_t> output_grid(100, 2); //array of 2D coordinates for 10 by 10 grid of points

//populate coords
long long int min_global = output_grid.pmap.getMinGlobalIndex();
FOR_ALL(i,0,output_grid.dims(0), {
output_grid(i, 0) = i/10;
output_grid(i, 1) = i%10;
output_grid(i, 0) = (min_global + i)/10;
output_grid(i, 1) = (min_global + i)%10;
}); // end parallel for

output_grid.update_host();
Expand Down Expand Up @@ -418,14 +419,9 @@ int main(int argc, char* argv[])
TpetraPartitionMap<> partitioned_output_map = output_grid.pmap;
TpetraDFArray<real_t> partitioned_output_values(partitioned_output_map, "partitioned output values");

//construct a unique source vector from ANN output using the subview constructor
//(for example's sake this is in fact a copy of the subview wrapped by the output as well)
TpetraDFArray<real_t> sub_output_values(ANNLayers(num_layers-1).distributed_outputs, ANNLayers(num_layers-1).distributed_outputs.comm_pmap,
ANNLayers(num_layers-1).distributed_outputs.comm_pmap.getMinGlobalIndex());

//general communication object between two vectors/arrays
TpetraCommunicationPlan<real_t> output_comms(partitioned_output_values, sub_output_values);
output_comms.execute_comms();
TpetraCommunicationPlan<real_t> output_grid_comms(partitioned_output_values, ANNLayers(num_layers-1).distributed_outputs);
output_grid_comms.execute_comms();
partitioned_output_values.print();

} // end of kokkos scope
Expand Down
8 changes: 7 additions & 1 deletion src/include/tpetra_wrapper_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,8 @@ void TpetraDCArray<T,Layout,ExecSpace,MemoryTraits>::repartition_vector() {
Teuchos::RCP<MV> managed_tpetra_vector = Teuchos::rcp(new MV(tpetra_pmap, this_array_));
managed_tpetra_vector->assign(*tpetra_vector);
tpetra_vector = managed_tpetra_vector;
this_array_.modify_device();
this_array_.sync_host();
// // migrate density vector if this is a restart file read
// if (simparam.restart_file&&repartition_node_densities)
// {
Expand Down Expand Up @@ -2499,7 +2501,7 @@ TpetraDFArray<T,Layout,ExecSpace,MemoryTraits>& TpetraDFArray<T,Layout,ExecSpace
if(temp.order_==1){
dims_[1] = 1;
}

global_dim1_ = temp.global_dim1_;
order_ = temp.order_;
length_ = temp.length_;
Expand Down Expand Up @@ -2723,6 +2725,8 @@ void TpetraDFArray<T,Layout,ExecSpace,MemoryTraits>::repartition_vector() {
Teuchos::RCP<MV> managed_tpetra_vector = Teuchos::rcp(new MV(tpetra_pmap, this_array_));
managed_tpetra_vector->assign(*tpetra_vector);
tpetra_vector = managed_tpetra_vector;
this_array_.modify_device();
this_array_.sync_host();
// // migrate density vector if this is a restart file read
// if (simparam.restart_file&&repartition_node_densities)
// {
Expand Down Expand Up @@ -3336,6 +3340,7 @@ void TpetraCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>::execute_comms(){
else{
destination_vector_.tpetra_vector->doImport(*source_vector_.tpetra_vector, *importer, Tpetra::INSERT);
}
destination_vector_.update_host();
}

template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
Expand Down Expand Up @@ -3458,6 +3463,7 @@ void TpetraLRCommunicationPlan<T,Layout,ExecSpace,MemoryTraits>::execute_comms()
else{
destination_vector_.tpetra_vector->doImport(*source_vector_.tpetra_vector, *importer, Tpetra::INSERT);
}
destination_vector_.update_host();
}

template <typename T, typename Layout, typename ExecSpace, typename MemoryTraits>
Expand Down

0 comments on commit 64c8073

Please sign in to comment.