Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Oct 23, 2024
1 parent 0fdfd4d commit 53fc17b
Showing 1 changed file with 18 additions and 44 deletions.
62 changes: 18 additions & 44 deletions paddle/fluid/operators/controlflow/fetch_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,51 +146,25 @@ class FetchV2Kernel {

bool deepcopy = ctx.Attr<bool>("deepcopy");

if (fetch_var->IsType<phi::DenseTensor>()) {
auto &src_item = fetch_var->Get<phi::DenseTensor>();
if (!src_item.IsInitialized()) {
return;
}
auto *dst_item = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(col)));
bool check_place =
src_item.place().GetType() == phi::AllocationType::CPU ||
src_item.place().GetType() == phi::AllocationType::GPUPINNED ||
src_item.place().GetType() == phi::AllocationType::CUSTOM;
PADDLE_ENFORCE_EQ(
check_place,
true,
common::errors::InvalidArgument("Tensor's place of input(X) must "
"be CPUPlace or CUDAPinnedPlace."));
if (deepcopy) {
DeepCopy(src_item, fetch_var_name, dst_item);
} else {
dst_item->ShareDataWith(src_item);
dst_item->set_lod(src_item.lod());
}
} else if (fetch_var->IsType<phi::SparseCooTensor>()) {
auto &src_item = fetch_var->Get<phi::SparseCooTensor>();
if (!src_item.initialized()) {
return;
}
fetch_list->at(col) = src_item;
auto &src_item = fetch_var->Get<phi::DenseTensor>();
if (!src_item.IsInitialized()) {
return;
}
auto *dst_item = &(PADDLE_GET(phi::DenseTensor, fetch_list->at(col)));
bool check_place =
src_item.place().GetType() == phi::AllocationType::CPU ||
src_item.place().GetType() == phi::AllocationType::GPUPINNED ||
src_item.place().GetType() == phi::AllocationType::CUSTOM;
PADDLE_ENFORCE_EQ(
check_place,
true,
common::errors::InvalidArgument("Tensor's place of input(X) must "
"be CPUPlace or CUDAPinnedPlace."));
if (deepcopy) {
DeepCopy(src_item, fetch_var_name, dst_item);
} else {
// auto &src_item = fetch_var->Get<phi::TensorArray>();
// phi::TensorArray tmp(src_item.size());
// fetch_list->at(col) = tmp;
// auto &dst_item = PADDLE_GET(phi::TensorArray, fetch_list->at(col));
// for (size_t i = 0; i < src_item.size(); ++i) {
// PADDLE_ENFORCE_EQ(
// src_item[i].place().GetType() == phi::AllocationType::CPU,
// true,
// common::errors::InvalidArgument(
// "Tensor's place of input(X) must be CPUPlace."));
// if (deepcopy) {
// DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
// } else {
// dst_item[i].ShareDataWith(src_item[i]);
// dst_item[i].set_lod(src_item[i].lod());
// }
// }
dst_item->ShareDataWith(src_item);
dst_item->set_lod(src_item.lod());
}
}
};
Expand Down

0 comments on commit 53fc17b

Please sign in to comment.