From 73e8a8c734ad03b647289985de1ac8e18d51ed3a Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Mon, 20 Jan 2025 21:02:59 +0800 Subject: [PATCH] Revert "[clean old comm][fluid_ops]c_allreduce_op.h (#70732)" This reverts commit f70042ab551078635fa5b28738aca23a6f298cd9. --- .../operators/collective/c_allreduce_op.h | 85 +++++++++++-------- .../operators/collective/c_gen_bkcl_id_op.cc | 18 ++++ .../operators/collective/c_gen_nccl_id_op.cc | 18 ++++ .../operators/collective/c_wait_comm_op.cc | 40 ++++++--- .../operators/collective/c_wait_compute_op.cc | 40 ++++++--- .../operators/collective/recv_v2_op.cu.cc | 49 +++++++---- .../operators/collective/send_v2_op.cu.cc | 49 +++++++---- 7 files changed, 201 insertions(+), 98 deletions(-) diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 9430527e4caf13..65b9e8de56a011 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -29,6 +29,7 @@ limitations under the License. */ defined(PADDLE_WITH_XPU_BKCL) #include "paddle/common/flags.h" #include "paddle/phi/core/platform/collective_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -175,24 +176,30 @@ class CAllReduceOpXPUKernel : public framework::OpKernel { const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); - - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(rid))); - comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(rid))); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - common::errors::Unavailable( - "BKCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - stream = comm_ctx->GetStream(); - VLOG(3) << "new comm_context_manager has rid " << rid; - + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + common::errors::Unavailable( + "BKCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + stream = comm_ctx->GetStream(); + VLOG(3) << "new comm_context_manager has rid " << rid; + } else { + comm = platform::BKCLCommContext::Instance().Get(rid, place); + stream = comm->stream(); + VLOG(3) << "old BKCLCommContext has rid " << rid; + } if (ctx.Attr("use_calc_stream")) { auto dev_ctx = phi::DeviceContextPool::Instance().Get(place); stream = static_cast(dev_ctx)->x_context()->xpu_stream; @@ -302,24 +309,30 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel { const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); - - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(rid))); - comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(rid))); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - common::errors::Unavailable( - "NCCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - stream = comm_ctx->GetStream(); - VLOG(3) << "new comm_context_manager has rid " << rid; - + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + common::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + stream = comm_ctx->GetStream(); + VLOG(3) << "new comm_context_manager has rid " << rid; + } else { + comm = platform::NCCLCommContext::Instance().Get(rid, place); + stream = comm->stream(); + VLOG(3) << "old NCCLCommContext has rid " << rid; + } if (ctx.Attr("use_calc_stream")) { // should not use global ctx for calc stream. // auto dev_ctx = phi::DeviceContextPool::Instance().Get(place); diff --git a/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc b/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc index 3479562f93ae55..324cdde5175c4e 100644 --- a/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc @@ -24,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/core/platform/device_context.h" #include "paddle/phi/core/platform/gen_comm_id_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { @@ -62,13 +63,30 @@ class CGenBKCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const phi::Place& dev_place) const override { + int rank = Attr("rank"); + int ring_id = Attr("ring_id"); + std::function func = [&](size_t i) -> std::string { return Output("Out"); }; + std::string endpoint = Attr("endpoint"); + std::vector bkcl_ids; bkcl_ids.resize(1); + if (!FLAGS_dynamic_static_unified_comm) { + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + if (rank == 0) { + GenBKCLID(&bkcl_ids); + std::vector endpoint_list = + Attr>("other_endpoints"); + platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id); + } else { + platform::RecvBroadCastCommID(server_fd, endpoint, &bkcl_ids, ring_id); + } + } + CopyBKCLIDToVar(bkcl_ids, func, scope); } }; diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index beda7cf0c1377b..5004439695097f 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/phi/core/platform/device_context.h" #include "paddle/phi/core/platform/gen_comm_id_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); namespace paddle::operators { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -57,13 +58,30 @@ class CGenNCCLIdOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const phi::Place& dev_place) const override { + int rank = Attr("rank"); + int ring_id = Attr("ring_id"); + std::function func = [&](size_t i) -> std::string { return Output("Out"); }; + std::string endpoint = Attr("endpoint"); + std::vector nccl_ids; nccl_ids.resize(1); + if (!FLAGS_dynamic_static_unified_comm) { + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + if (rank == 0) { + GenNCCLID(&nccl_ids); + std::vector endpoint_list = + Attr>("other_endpoints"); + platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id); + } else { + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id); + } + } + CopyNCCLIDToVar(nccl_ids, func, scope); } }; diff --git a/paddle/fluid/operators/collective/c_wait_comm_op.cc b/paddle/fluid/operators/collective/c_wait_comm_op.cc index 8226f6d1d495e2..ce9387d5aea183 100644 --- a/paddle/fluid/operators/collective/c_wait_comm_op.cc +++ b/paddle/fluid/operators/collective/c_wait_comm_op.cc @@ -22,6 +22,7 @@ class Scope; #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/platform/collective_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif namespace paddle::operators { @@ -55,20 +56,31 @@ class CWaitCommOp : public framework::OperatorBase { const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); - - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(ring_id))); - phi::distributed::NCCLCommContext* comm_ctx = - static_cast( - comm_context_manager.Get(std::to_string(ring_id))); - comm_stream = comm_ctx->GetStream(); - event = comm_ctx->GetComputeEvent(); - VLOG(3) << "new comm_context_manager has rid " << ring_id; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(ring_id))); + phi::distributed::NCCLCommContext* comm_ctx = + static_cast( + comm_context_manager.Get(std::to_string(ring_id))); + comm_stream = comm_ctx->GetStream(); + event = comm_ctx->GetComputeEvent(); + VLOG(3) << "new comm_context_manager has rid " << ring_id; + } else { + comm_stream = + platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); + + event = platform::NCCLCommContext::Instance() + .Get(ring_id, place) + ->comm_event(); + VLOG(3) << "old NCCLCommContext has rid " << ring_id; + } // comm_stream-->event-->compute_stream #ifdef PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/collective/c_wait_compute_op.cc b/paddle/fluid/operators/collective/c_wait_compute_op.cc index 234832a6c46059..4d8a5f158c679b 100644 --- a/paddle/fluid/operators/collective/c_wait_compute_op.cc +++ b/paddle/fluid/operators/collective/c_wait_compute_op.cc @@ -22,6 +22,7 @@ class Scope; #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/platform/collective_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif namespace paddle::operators { @@ -55,20 +56,31 @@ class CWaitComputeOp : public framework::OperatorBase { const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); - - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(ring_id))); - phi::distributed::NCCLCommContext* comm_ctx = - static_cast( - comm_context_manager.Get(std::to_string(ring_id))); - comm_stream = comm_ctx->GetStream(); - event = comm_ctx->GetComputeEvent(); - VLOG(3) << "new comm_context_manager has rid " << ring_id; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(ring_id))); + phi::distributed::NCCLCommContext* comm_ctx = + static_cast( + comm_context_manager.Get(std::to_string(ring_id))); + comm_stream = comm_ctx->GetStream(); + event = comm_ctx->GetComputeEvent(); + VLOG(3) << "new comm_context_manager has rid " << ring_id; + } else { + comm_stream = + platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); + + event = platform::NCCLCommContext::Instance() + .Get(ring_id, place) + ->compute_event(); + VLOG(3) << "old NCCLCommContext has rid " << ring_id; + } // compute_stream-->event-->comm_stream #ifdef PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc index 460331f03dab58..b2866e14ea6782 100644 --- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc +++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/platform/collective_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif #include "paddle/fluid/distributed/collective/process_group.h" @@ -161,23 +162,37 @@ class RecvOpV2CUDAKernel : public framework::OpKernel { const auto &comm_context_manager = phi::distributed::CommContextManager::GetInstance(); - - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(rid))); - comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(rid))); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - common::errors::Unavailable( - "NCCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - stream = comm_ctx->GetStream(); - VLOG(3) << "new comm_context_manager has rid " << rid; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + common::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + stream = comm_ctx->GetStream(); + VLOG(3) << "new comm_context_manager has rid " << rid; + } else { + comm = platform::NCCLCommContext::Instance().Get(rid, place); + PADDLE_ENFORCE_LT( + peer, + comm->nranks(), + common::errors::InvalidArgument("The value of peer (%d) you set must " + "be less than comm->nranks (%d).", + peer, + comm->nranks())); + stream = comm->stream(); + VLOG(3) << "old NCCLCommContext has rid " << rid; + } if (ctx.Attr("use_calc_stream")) { // should ExecutionContext for calc stream. diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc index 36901777e27d1e..7a9861bf9d6213 100644 --- a/paddle/fluid/operators/collective/send_v2_op.cu.cc +++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/platform/collective_helper.h" +COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/phi/api/include/tensor.h" @@ -145,23 +146,37 @@ class SendOpV2CUDAKernel : public framework::OpKernel { const auto& comm_context_manager = phi::distributed::CommContextManager::GetInstance(); - - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(rid))); - comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(rid))); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - common::errors::Unavailable( - "NCCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - stream = comm_ctx->GetStream(); - VLOG(3) << "new comm_context_manager has rid " << rid; + if (FLAGS_dynamic_static_unified_comm) { + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. " + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(rid))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(rid))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + common::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + stream = comm_ctx->GetStream(); + VLOG(3) << "new comm_context_manager has rid " << rid; + } else { + comm = platform::NCCLCommContext::Instance().Get(rid, place); + PADDLE_ENFORCE_LT( + peer, + comm->nranks(), + common::errors::InvalidArgument("The value of peer (%d) you set must " + "be less than comm->nranks (%d).", + peer, + comm->nranks())); + stream = comm->stream(); + VLOG(3) << "old NCCLCommContext has rid " << rid; + } if (ctx.Attr("use_calc_stream")) { // should ExecutionContext for calc stream.