Skip to content

Commit

Permalink
!16944 Fix Server s_addr use INADDR_ANY
Browse files Browse the repository at this point in the history
Merge pull request !16944 from wuxiaotong/cherry-pick-1734078855
  • Loading branch information
wuxiaotong authored and it-is-a-robot committed Dec 14, 2024
1 parent e47a99a commit 5f2e931
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 18 deletions.
19 changes: 13 additions & 6 deletions torch_npu/csrc/distributed/ParallelTcpServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include <netinet/in.h>
#include <unistd.h>
#include <fcntl.h>

#include <arpa/inet.h>
#include "c10/util/Logging.h"
#include "ParallelTcpServer.hpp"

Expand Down Expand Up @@ -102,8 +102,15 @@ void ClientIoContext::FlushSendBuf() noexcept
}


ParallelTcpServer::ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerProcFn process) noexcept
: threadNum_{ std::max(4U, threadNum) }, port_{ port }, process_{ std::move(process) }
/**
 * Stores the server configuration; no socket is opened here (Start() passes
 * host_ and port_ to CreateSocket()).
 *
 * @param threadNum requested thread count; values below 4 are raised to 4.
 * @param host      address string kept for creating the listening socket.
 * @param port      port kept for creating the listening socket.
 * @param process   request handler callback, taken by value and moved into place.
 *
 * The top-level const on the by-value `host` parameter is dropped in this
 * definition only; it is not part of the function type, so the declaration
 * in the header still matches. Dropping it lets the string be moved instead
 * of copied, which also keeps this noexcept constructor free of a throwing
 * copy.
 */
ParallelTcpServer::ParallelTcpServer(
    uint32_t threadNum,
    std::string host,
    uint16_t port,
    ServerProcFn process) noexcept
    // Initializers are listed in member declaration order
    // (threadNum_, port_, host_, process_), which is the order they run in.
    : threadNum_{ std::max(4U, threadNum) },
      port_{ port },
      host_{ std::move(host) },
      process_{ std::move(process) }
{}

int ParallelTcpServer::Start() noexcept
Expand All @@ -114,7 +121,7 @@ int ParallelTcpServer::Start() noexcept
return -1;
}

listenSocket_ = CreateSocket(port_);
listenSocket_ = CreateSocket(host_, port_);
if (listenSocket_ < 0) {
delete[] buffer_;
buffer_ = nullptr;
Expand Down Expand Up @@ -204,11 +211,11 @@ void ParallelTcpServer::WakeupWaitingClients(const std::string &key) noexcept
}
}

int ParallelTcpServer::CreateSocket(uint16_t port) noexcept
int ParallelTcpServer::CreateSocket(const std::string host, uint16_t port) noexcept
{
struct sockaddr_in servAddr {};
servAddr.sin_family = AF_INET;
servAddr.sin_addr.s_addr = INADDR_ANY;
servAddr.sin_addr.s_addr = inet_addr(host.c_str());
servAddr.sin_port = htons(port);

auto sockFd = ::socket(AF_INET, SOCK_STREAM, 0);
Expand Down
5 changes: 3 additions & 2 deletions torch_npu/csrc/distributed/ParallelTcpServer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ using ServerProcFn = std::function<StoreMessage(int fd, const StoreMessage &req)
*/
class ParallelTcpServer {
public:
explicit ParallelTcpServer(uint32_t threadNum, uint16_t port, ServerProcFn process) noexcept;
explicit ParallelTcpServer(uint32_t threadNum, const std::string host, uint16_t port, ServerProcFn process) noexcept;

int Start() noexcept;

Expand All @@ -117,7 +117,7 @@ class ParallelTcpServer {
void WakeupWaitingClients(const std::string &key) noexcept;

private:
static int CreateSocket(uint16_t port) noexcept;
static int CreateSocket(const std::string host, uint16_t port) noexcept;

static int CreateEpoll(int targetFd = -1) noexcept;

Expand All @@ -134,6 +134,7 @@ class ParallelTcpServer {
private:
const uint32_t threadNum_;
const std::uint16_t port_;
const std::string host_;
const ServerProcFn process_;
int listenSocket_{ -1 };
int epCtlFd_{ -1 };
Expand Down
15 changes: 8 additions & 7 deletions torch_npu/csrc/distributed/ParallelTcpStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

namespace c10d {
namespace pta {
ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10::optional<std::size_t> numWorkers)
ParallelStoreServer::ParallelStoreServer(std::string initKey, const std::string host, uint16_t port,
c10::optional<std::size_t> numWorkers)
: initKey_{ std::move(initKey) }, numWorkers_{ numWorkers }
{
auto threadNum = 4U;
Expand All @@ -33,7 +34,7 @@ ParallelStoreServer::ParallelStoreServer(std::string initKey, uint16_t port, c10
}

InitializeHandlers();
server_ = std::make_unique<pta::ParallelTcpServer>(threadNum, port,
server_ = std::make_unique<pta::ParallelTcpServer>(threadNum, host, port,
[this](int fd, const pta::StoreMessage &request) { return ProcessRequest(fd, request); });
if (server_->Start() != 0) {
throw std::runtime_error{
Expand Down Expand Up @@ -257,9 +258,9 @@ ParallelTcpStore::ParallelTcpStore(const std::string &host, const c10d::TCPStore
{
if (opts.isServer) {
if (opts.multiTenant) {
server_ = GetSharedServer(initKey_, opts.port, opts.numWorkers);
server_ = GetSharedServer(initKey_, host, opts.port, opts.numWorkers);
} else {
server_ = std::make_shared<pta::ParallelStoreServer>(initKey_, opts.port, opts.numWorkers);
server_ = std::make_shared<pta::ParallelStoreServer>(initKey_, host, opts.port, opts.numWorkers);
}
}

Expand Down Expand Up @@ -411,8 +412,8 @@ void ParallelTcpStore::DoWait(const pta::StoreMessage &req, pta::StoreMessage &r
}
}

std::shared_ptr<pta::ParallelStoreServer> ParallelTcpStore::GetSharedServer(const std::string &initKey, uint16_t port,
c10::optional<std::size_t> numWorkers)
std::shared_ptr<pta::ParallelStoreServer> ParallelTcpStore::GetSharedServer(const std::string &initKey,
const std::string host, uint16_t port, c10::optional<std::size_t> numWorkers)
{
std::unique_lock<std::mutex> lockGuard{ cacheServerMutex_ };
auto pos = cachedServers_.find(port);
Expand All @@ -425,7 +426,7 @@ std::shared_ptr<pta::ParallelStoreServer> ParallelTcpStore::GetSharedServer(cons
cachedServers_.erase(pos);
}

auto server = std::make_shared<pta::ParallelStoreServer>(initKey, port, numWorkers);
auto server = std::make_shared<pta::ParallelStoreServer>(initKey, host, port, numWorkers);
cachedServers_.emplace(port, server);
return server;
}
Expand Down
7 changes: 4 additions & 3 deletions torch_npu/csrc/distributed/ParallelTcpStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ namespace c10d {
namespace pta {
class ParallelStoreServer {
public:
ParallelStoreServer(std::string initKey, uint16_t port, c10::optional<std::size_t> numWorkers);
ParallelStoreServer(std::string initKey, const std::string host, uint16_t port,
c10::optional<std::size_t> numWorkers);
virtual ~ParallelStoreServer() noexcept;
void WaitWorkers(const std::chrono::milliseconds &timeout) noexcept;

Expand Down Expand Up @@ -85,8 +86,8 @@ class ParallelTcpStore : public Store {
private:
int64_t IncreaseKey(const std::string &key, int64_t value);
void DoWait(const pta::StoreMessage &req, pta::StoreMessage &res);
static std::shared_ptr<pta::ParallelStoreServer> GetSharedServer(const std::string &initKey, uint16_t port,
c10::optional<std::size_t> numWorkers);
static std::shared_ptr<pta::ParallelStoreServer> GetSharedServer(const std::string &initKey,
const std::string host, uint16_t port, c10::optional<std::size_t> numWorkers);

private:
pta::TcpClient client_;
Expand Down

0 comments on commit 5f2e931

Please sign in to comment.