From 1fee8686cc6e34e1a55918fa01894ac2ae6e48fb Mon Sep 17 00:00:00 2001 From: co63oc Date: Thu, 9 Jan 2025 08:55:33 +0800 Subject: [PATCH 1/2] Fix --- paddle/phi/kernels/impl/margin_cross_entropy.cu.h | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/paddle/phi/kernels/impl/margin_cross_entropy.cu.h b/paddle/phi/kernels/impl/margin_cross_entropy.cu.h index e9590e05d8453..1bdc534423248 100644 --- a/paddle/phi/kernels/impl/margin_cross_entropy.cu.h +++ b/paddle/phi/kernels/impl/margin_cross_entropy.cu.h @@ -94,17 +94,8 @@ void GetClassInterval(const gpuStream_t& stream, phi::distributed::CommContextManager::GetInstance(); phi::distributed::NCCLCommContext* comm_ctx = nullptr; - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library by " - "setting environment " - "variable FLAGS_dynamic_static_unified_comm True. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(rid))); comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(rid))); + dev_ctx.GetCommContext()); PADDLE_ENFORCE_NE(comm_ctx, nullptr, common::errors::Unavailable( From 4430fc1b60ea70d0e5c55c95165b2455e95d376f Mon Sep 17 00:00:00 2001 From: co63oc Date: Thu, 9 Jan 2025 10:21:04 +0800 Subject: [PATCH 2/2] Fix --- paddle/phi/kernels/impl/margin_cross_entropy.cu.h | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/impl/margin_cross_entropy.cu.h b/paddle/phi/kernels/impl/margin_cross_entropy.cu.h index 1bdc534423248..8103a2fad99f8 100644 --- a/paddle/phi/kernels/impl/margin_cross_entropy.cu.h +++ b/paddle/phi/kernels/impl/margin_cross_entropy.cu.h @@ -90,12 +90,9 @@ void GetClassInterval(const gpuStream_t& stream, auto task = pg->AllReduce(in_tensor, out_tensor, opts); task->Wait(); } else { - const auto& comm_context_manager = - phi::distributed::CommContextManager::GetInstance(); - phi::distributed::NCCLCommContext* comm_ctx = nullptr; - - comm_ctx = static_cast( - dev_ctx.GetCommContext()); + phi::distributed::NCCLCommContext* comm_ctx = + static_cast( + dev_ctx.GetCommContext()); PADDLE_ENFORCE_NE(comm_ctx, nullptr, common::errors::Unavailable(