Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support IPv6 #39

Open
wants to merge 4 commits into
base: byteps
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions src/rdma_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#ifdef DMLC_USE_RDMA

#include <algorithm>
#include "rdma_utils.h"
#include "rdma_transport.h"

Expand Down Expand Up @@ -125,19 +126,32 @@ class RDMAVan : public Van {
struct sockaddr_in addr;
algs marked this conversation as resolved.
Show resolved Hide resolved
memset(&addr, 0, sizeof(addr));

int af = PF_INET;
algs marked this conversation as resolved.
Show resolved Hide resolved
int ret = -EINVAL;
struct addrinfo *res;

auto val = Environment::Get()->find("DMLC_NODE_HOST");
std::string val_str = std::string(val);
if (val) {
PS_VLOG(1) << "bind to DMLC_NODE_HOST: " << std::string(val);
PS_VLOG(1) << "bind to DMLC_NODE_HOST: " << val_str;
std::size_t n = std::count(val_str.begin(), val_str.end(), ':');
if (n > 1) {
af = PF_INET6;
algs marked this conversation as resolved.
Show resolved Hide resolved
}
addr.sin_addr.s_addr = inet_addr(val);
}

addr.sin_family = AF_INET;
// addr.sin_family = AF_INET;
addr.sin_family = af;
int port = node.port;
addr.sin_port = htons(port);
ret = getaddrinfo(val_str.c_str(), std::to_string(port).c_str(), NULL, &res);
CHECK(ret >= 0) << "could not getaddrinfo address " << val_str << " error code " << ret;
unsigned seed = static_cast<unsigned>(time(NULL) + port);
for (int i = 0; i < max_retry + 1; ++i) {
addr.sin_port = htons(port);
if (rdma_bind_addr(listener_,
reinterpret_cast<struct sockaddr *>(&addr)) == 0) {
// if (rdma_bind_addr(listener_,
// reinterpret_cast<struct sockaddr *>(&addr)) == 0) {
if (rdma_bind_addr(listener_, res->ai_addr) == 0) {
break;
}
if (i == max_retry) {
Expand Down Expand Up @@ -207,11 +221,11 @@ class RDMAVan : public Van {
CHECK_EQ(rc, 0) << "getaddrinfo failed: " << gai_strerror(rc);

CHECK_EQ(rdma_resolve_addr(endpoint->cm_id, addr->ai_addr,
remote_addr->ai_addr, kTimeoutms), 0)
(struct sockaddr *)remote_addr->ai_addr, kTimeoutms), 0)
<< "Resolve RDMA address failed with errno: " << strerror(errno);
} else {
CHECK_EQ(rdma_resolve_addr(endpoint->cm_id, nullptr,
remote_addr->ai_addr, kTimeoutms),
(struct sockaddr *)remote_addr->ai_addr, kTimeoutms),
0)
<< "Resolve RDMA address failed with errno: " << strerror(errno);
}
Expand Down
28 changes: 23 additions & 5 deletions src/zmq_van.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#define PS_ZMQ_VAN_H_
#include <stdio.h>
#include <cstdlib>
#include <algorithm>
#include <zmq.h>
#include <string>
#include <cstring>
Expand Down Expand Up @@ -98,15 +99,21 @@ class ZMQVan : public Van {
int Bind(const Node& node, int max_retry) override {
receiver_ = zmq_socket(context_, ZMQ_ROUTER);
int option = 1;
std::string hostname = node.hostname.empty() ? "*" : node.hostname;
size_t n = std::count(hostname.begin(), hostname.end(), ':');
CHECK(!zmq_setsockopt(receiver_, ZMQ_ROUTER_MANDATORY, &option, sizeof(option)))
<< zmq_strerror(errno);
CHECK(receiver_ != NULL)
<< "create receiver socket failed: " << zmq_strerror(errno);
int local = GetEnv("DMLC_LOCAL", 0);
std::string hostname = node.hostname.empty() ? "*" : node.hostname;
int use_kubernetes = GetEnv("DMLC_USE_KUBERNETES", 0);
if (use_kubernetes > 0 && node.role == Node::SCHEDULER) {
hostname = "0.0.0.0";
hostname = (n > 1) ? "::/0" : "0.0.0.0";
}
if (n > 1) {
CHECK(!zmq_setsockopt(receiver_, ZMQ_IPV6, &option, sizeof(option)))
<< zmq_strerror(errno);
PS_VLOG(1) << "bind IPv6 socket to host " << hostname;
}
std::string addr = local ? "ipc:///tmp/" : "tcp://" + hostname + ":";
int port = node.port;
Expand All @@ -117,9 +124,9 @@ class ZMQVan : public Van {
if (ret == 0) break;
if (i == max_retry) {
port = -1;
int zmq_err = zmq_errno();
LOG(FATAL) << "Reached max retry for bind: " << zmq_strerror(zmq_err)
<< ". errno = " << zmq_err;
int zmq_err = zmq_errno();
LOG(FATAL) << "Reached max retry for bind: " << zmq_strerror(zmq_err)
<< ". errno = " << zmq_err;
} else {
port = 10000 + rand_r(&seed) % 40000;
}
Expand All @@ -137,6 +144,7 @@ class ZMQVan : public Van {
CHECK_NE(node.port, node.kEmpty);
CHECK(node.hostname.size());
int id = node.id;
bool is_ipv6 = false;
mu_.lock();
auto it = senders_.find(id);
if (it != senders_.end()) {
Expand All @@ -155,6 +163,16 @@ class ZMQVan : public Van {
<< zmq_strerror(errno)
<< ". it often can be solved by \"sudo ulimit -n 65536\""
<< " or edit /etc/security/limits.conf";
std::string hostname = node.hostname.empty() ? "*" : node.hostname;
size_t n = std::count(hostname.begin(), hostname.end(), ':');
PS_VLOG(1) << "connect to host " << hostname;
if (n > 1) {
int option = 1;
is_ipv6 = true;
PS_VLOG(1) << "connect with ipv6 socket";
CHECK(!zmq_setsockopt(sender, ZMQ_IPV6, &option, sizeof(option)))
<< zmq_strerror(errno);
}
if (my_node_.id != Node::kEmpty) {
std::string my_id = "ps" + std::to_string(my_node_.id);
zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size());
Expand Down