Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[clean old comm][fluid_ops]c_allreduce_op.h" #70906

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 49 additions & 36 deletions paddle/fluid/operators/collective/c_allreduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/common/flags.h"
#include "paddle/phi/core/platform/collective_helper.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Expand Down Expand Up @@ -175,24 +176,30 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::BKCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
common::errors::Unavailable(
"BKCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;

if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::BKCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
common::errors::Unavailable(
"BKCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::BKCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old BKCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = phi::DeviceContextPool::Instance().Get(place);
stream = static_cast<phi::XPUContext*>(dev_ctx)->x_context()->xpu_stream;
Expand Down Expand Up @@ -302,24 +309,30 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
common::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;

if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
common::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}
if (ctx.Attr<bool>("use_calc_stream")) {
// should not use global ctx for calc stream.
// auto dev_ctx = phi::DeviceContextPool::Instance().Get(place);
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/phi/core/platform/device_context.h"
#include "paddle/phi/core/platform/gen_comm_id_helper.h"

COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace paddle {
namespace operators {

Expand Down Expand Up @@ -62,13 +63,30 @@ class CGenBKCLIdOp : public framework::OperatorBase {

void RunImpl(const framework::Scope& scope,
const phi::Place& dev_place) const override {
int rank = Attr<int>("rank");
int ring_id = Attr<int>("ring_id");

std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};

std::string endpoint = Attr<std::string>("endpoint");

std::vector<BKCLUniqueId> bkcl_ids;
bkcl_ids.resize(1);

if (!FLAGS_dynamic_static_unified_comm) {
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
if (rank == 0) {
GenBKCLID(&bkcl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &bkcl_ids, ring_id);
}
}

CopyBKCLIDToVar(bkcl_ids, func, scope);
}
};
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/phi/core/platform/device_context.h"
#include "paddle/phi/core/platform/gen_comm_id_helper.h"

COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace paddle::operators {

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
Expand Down Expand Up @@ -57,13 +58,30 @@ class CGenNCCLIdOp : public framework::OperatorBase {

void RunImpl(const framework::Scope& scope,
const phi::Place& dev_place) const override {
int rank = Attr<int>("rank");
int ring_id = Attr<int>("ring_id");

std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
return Output("Out");
};

std::string endpoint = Attr<std::string>("endpoint");

std::vector<ncclUniqueId> nccl_ids;
nccl_ids.resize(1);

if (!FLAGS_dynamic_static_unified_comm) {
int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
if (rank == 0) {
GenNCCLID(&nccl_ids);
std::vector<std::string> endpoint_list =
Attr<std::vector<std::string>>("other_endpoints");
platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
} else {
platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
}
}

CopyNCCLIDToVar(nccl_ids, func, scope);
}
};
Expand Down
40 changes: 26 additions & 14 deletions paddle/fluid/operators/collective/c_wait_comm_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Scope;
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/platform/collective_helper.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle::operators {
Expand Down Expand Up @@ -55,20 +56,31 @@ class CWaitCommOp : public framework::OperatorBase {

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();

event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->comm_event();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}

// comm_stream-->event-->compute_stream
#ifdef PADDLE_WITH_HIP
Expand Down
40 changes: 26 additions & 14 deletions paddle/fluid/operators/collective/c_wait_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Scope;
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/platform/collective_helper.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

namespace paddle::operators {
Expand Down Expand Up @@ -55,20 +56,31 @@ class CWaitComputeOp : public framework::OperatorBase {

const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(ring_id)));
phi::distributed::NCCLCommContext* comm_ctx =
static_cast<phi::distributed::NCCLCommContext*>(
comm_context_manager.Get(std::to_string(ring_id)));
comm_stream = comm_ctx->GetStream();
event = comm_ctx->GetComputeEvent();
VLOG(3) << "new comm_context_manager has rid " << ring_id;
} else {
comm_stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();

event = platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->compute_event();
VLOG(3) << "old NCCLCommContext has rid " << ring_id;
}

// compute_stream-->event-->comm_stream
#ifdef PADDLE_WITH_HIP
Expand Down
49 changes: 32 additions & 17 deletions paddle/fluid/operators/collective/recv_v2_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/distributed/comm_context_manager.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
#include "paddle/phi/core/platform/collective_helper.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

#include "paddle/fluid/distributed/collective/process_group.h"
Expand Down Expand Up @@ -161,23 +162,37 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {

const auto &comm_context_manager =
phi::distributed::CommContextManager::GetInstance();

PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
common::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
if (FLAGS_dynamic_static_unified_comm) {
PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(rid)),
true,
common::errors::InvalidArgument(
"You choose to use new communication library by "
"setting environment "
"variable FLAGS_dynamic_static_unified_comm True. "
"But ring_id(%d) is "
"not found in comm_context_manager.",
std::to_string(rid)));
comm_ctx = static_cast<phi::distributed::NCCLCommContext *>(
comm_context_manager.Get(std::to_string(rid)));
PADDLE_ENFORCE_NE(comm_ctx,
nullptr,
common::errors::Unavailable(
"NCCLCommContext is nullptr, collective op should "
"has ring_id attr."));
stream = comm_ctx->GetStream();
VLOG(3) << "new comm_context_manager has rid " << rid;
} else {
comm = platform::NCCLCommContext::Instance().Get(rid, place);
PADDLE_ENFORCE_LT(
peer,
comm->nranks(),
common::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).",
peer,
comm->nranks()));
stream = comm->stream();
VLOG(3) << "old NCCLCommContext has rid " << rid;
}

if (ctx.Attr<bool>("use_calc_stream")) {
// should ExecutionContext for calc stream.
Expand Down
Loading
Loading