-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor zmq pub and sub into a zmq backend
- Loading branch information
Showing
13 changed files
with
689 additions
and
535 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import zmq | ||
|
||
from posttroll import config | ||
|
||
def _set_tcp_keepalive(socket): | ||
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE, config.get("tcp_keepalive", None)) | ||
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_CNT, config.get("tcp_keepalive_cnt", None)) | ||
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_IDLE, config.get("tcp_keepalive_idle", None)) | ||
_set_int_sockopt(socket, zmq.TCP_KEEPALIVE_INTVL, config.get("tcp_keepalive_intvl", None)) | ||
|
||
|
||
def _set_int_sockopt(socket, param, value): | ||
if value is not None: | ||
socket.setsockopt(param, int(value)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from threading import Lock | ||
from urllib.parse import urlsplit, urlunsplit | ||
import zmq | ||
import logging | ||
|
||
from posttroll import get_context | ||
from posttroll.backends.zmq import _set_tcp_keepalive | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class UnsecureZMQPublisher: | ||
"""Unsecure ZMQ implementation of the publisher class.""" | ||
|
||
def __init__(self, address, name="", min_port=None, max_port=None): | ||
"""Bind the publisher class to a port.""" | ||
self.name = name | ||
self.destination = address | ||
self.publish_socket = None | ||
self.min_port = min_port | ||
self.max_port = max_port | ||
self.port_number = None | ||
self._pub_lock = Lock() | ||
|
||
def start(self): | ||
"""Start the publisher. | ||
""" | ||
self.publish_socket = get_context().socket(zmq.PUB) | ||
_set_tcp_keepalive(self.publish_socket) | ||
|
||
self._bind() | ||
LOGGER.info(f"Publisher for {self.destination} started on port {self.port_number}.") | ||
return self | ||
|
||
def _bind(self): | ||
# Check for port 0 (random port) | ||
u__ = urlsplit(self.destination) | ||
port = u__.port | ||
if port == 0: | ||
dest = urlunsplit((u__.scheme, u__.hostname, | ||
u__.path, u__.query, u__.fragment)) | ||
self.port_number = self.publish_socket.bind_to_random_port( | ||
dest, | ||
min_port=self.min_port, | ||
max_port=self.max_port) | ||
netloc = u__.hostname + ":" + str(self.port_number) | ||
self.destination = urlunsplit((u__.scheme, netloc, u__.path, | ||
u__.query, u__.fragment)) | ||
else: | ||
self.publish_socket.bind(self.destination) | ||
self.port_number = port | ||
|
||
def send(self, msg): | ||
"""Send the given message.""" | ||
with self._pub_lock: | ||
self.publish_socket.send_string(msg) | ||
|
||
def stop(self): | ||
"""Stop the publisher.""" | ||
self.publish_socket.setsockopt(zmq.LINGER, 1) | ||
self.publish_socket.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
from threading import Lock | ||
from urllib.parse import urlsplit | ||
from posttroll.message import _MAGICK, Message | ||
from zmq import Poller, SUB, SUBSCRIBE, POLLIN, PULL, ZMQError, NOBLOCK, LINGER | ||
from time import sleep | ||
import logging | ||
|
||
from posttroll import get_context | ||
from posttroll.backends.zmq import _set_tcp_keepalive | ||
|
||
|
||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
class UnsecureZMQSubscriber: | ||
"""Unsecure ZMQ implementation of the subscriber.""" | ||
|
||
def __init__(self, addresses, topics='', message_filter=None, translate=False): | ||
"""Initialize the subscriber.""" | ||
self._topics = topics | ||
self._filter = message_filter | ||
self._translate = translate | ||
|
||
self.sub_addr = {} | ||
self.addr_sub = {} | ||
|
||
self._hooks = [] | ||
self._hooks_cb = {} | ||
|
||
self.poller = Poller() | ||
self._lock = Lock() | ||
|
||
self.update(addresses) | ||
|
||
self._loop = None | ||
|
||
def add(self, address, topics=None): | ||
"""Add *address* to the subscribing list for *topics*. | ||
It topics is None we will subscribe to already specified topics. | ||
""" | ||
with self._lock: | ||
if address in self.addresses: | ||
return | ||
|
||
topics = topics or self._topics | ||
LOGGER.info("Subscriber adding address %s with topics %s", | ||
str(address), str(topics)) | ||
subscriber = self._add_sub_socket(address, topics) | ||
self.sub_addr[subscriber] = address | ||
self.addr_sub[address] = subscriber | ||
|
||
def _add_sub_socket(self, address, topics): | ||
subscriber = get_context().socket(SUB) | ||
_set_tcp_keepalive(subscriber) | ||
for t__ in topics: | ||
subscriber.setsockopt_string(SUBSCRIBE, str(t__)) | ||
subscriber.connect(address) | ||
|
||
if self.poller: | ||
self.poller.register(subscriber, POLLIN) | ||
return subscriber | ||
|
||
def remove(self, address): | ||
"""Remove *address* from the subscribing list for *topics*.""" | ||
with self._lock: | ||
try: | ||
subscriber = self.addr_sub[address] | ||
except KeyError: | ||
return | ||
LOGGER.info("Subscriber removing address %s", str(address)) | ||
del self.addr_sub[address] | ||
del self.sub_addr[subscriber] | ||
self._remove_sub_socket(subscriber) | ||
|
||
def _remove_sub_socket(self, subscriber): | ||
if self.poller: | ||
self.poller.unregister(subscriber) | ||
subscriber.close() | ||
|
||
def update(self, addresses): | ||
"""Update with a set of addresses.""" | ||
if isinstance(addresses, str): | ||
addresses = [addresses, ] | ||
current_addresses, new_addresses = set(self.addresses), set(addresses) | ||
addresses_to_remove = current_addresses.difference(new_addresses) | ||
addresses_to_add = new_addresses.difference(current_addresses) | ||
for addr in addresses_to_remove: | ||
self.remove(addr) | ||
for addr in addresses_to_add: | ||
self.add(addr) | ||
return bool(addresses_to_remove or addresses_to_add) | ||
|
||
def add_hook_sub(self, address, topics, callback): | ||
"""Specify a SUB *callback* in the same stream (thread) as the main receive loop. | ||
The callback will be called with the received messages from the | ||
specified subscription. | ||
Good for operations, which is required to be done in the same thread as | ||
the main recieve loop (e.q operations on the underlying sockets). | ||
""" | ||
topics = topics | ||
LOGGER.info("Subscriber adding SUB hook %s for topics %s", | ||
str(address), str(topics)) | ||
socket = self._add_sub_socket(address, topics) | ||
self._add_hook(socket, callback) | ||
|
||
def add_hook_pull(self, address, callback): | ||
"""Specify a PULL *callback* in the same stream (thread) as the main receive loop. | ||
The callback will be called with the received messages from the | ||
specified subscription. Good for pushed 'inproc' messages from another thread. | ||
""" | ||
LOGGER.info("Subscriber adding PULL hook %s", str(address)) | ||
socket = get_context().socket(PULL) | ||
socket.connect(address) | ||
if self.poller: | ||
self.poller.register(socket, POLLIN) | ||
self._add_hook(socket, callback) | ||
|
||
def _add_hook(self, socket, callback): | ||
"""Add a generic hook. The passed socket has to be "receive only".""" | ||
self._hooks.append(socket) | ||
self._hooks_cb[socket] = callback | ||
|
||
|
||
@property | ||
def addresses(self): | ||
"""Get the addresses.""" | ||
return self.sub_addr.values() | ||
|
||
@property | ||
def subscribers(self): | ||
"""Get the subscribers.""" | ||
return self.sub_addr.keys() | ||
|
||
def recv(self, timeout=None): | ||
"""Receive, optionally with *timeout* in seconds.""" | ||
if timeout: | ||
timeout *= 1000. | ||
|
||
for sub in list(self.subscribers) + self._hooks: | ||
self.poller.register(sub, POLLIN) | ||
self._loop = True | ||
try: | ||
while self._loop: | ||
sleep(0) | ||
try: | ||
socks = dict(self.poller.poll(timeout=timeout)) | ||
if socks: | ||
for sub in self.subscribers: | ||
if sub in socks and socks[sub] == POLLIN: | ||
received = sub.recv_string(NOBLOCK) | ||
m__ = Message.decode(received) | ||
if not self._filter or self._filter(m__): | ||
if self._translate: | ||
url = urlsplit(self.sub_addr[sub]) | ||
host = url[1].split(":")[0] | ||
m__.sender = (m__.sender.split("@")[0] | ||
+ "@" + host) | ||
yield m__ | ||
|
||
for sub in self._hooks: | ||
if sub in socks and socks[sub] == POLLIN: | ||
m__ = Message.decode(sub.recv_string(NOBLOCK)) | ||
self._hooks_cb[sub](m__) | ||
else: | ||
# timeout | ||
yield None | ||
except ZMQError as err: | ||
if self._loop: | ||
LOGGER.exception("Receive failed: %s", str(err)) | ||
finally: | ||
for sub in list(self.subscribers) + self._hooks: | ||
self.poller.unregister(sub) | ||
|
||
def __call__(self, **kwargs): | ||
"""Handle calls with class instance.""" | ||
return self.recv(**kwargs) | ||
|
||
def stop(self): | ||
"""Stop the subscriber.""" | ||
self._loop = False | ||
|
||
def close(self): | ||
"""Close the subscriber: stop it and close the local subscribers.""" | ||
self.stop() | ||
for sub in list(self.subscribers) + self._hooks: | ||
try: | ||
sub.setsockopt(LINGER, 1) | ||
sub.close() | ||
except ZMQError: | ||
pass | ||
|
||
def __del__(self): | ||
"""Clean up after the instance is deleted.""" | ||
for sub in list(self.subscribers) + self._hooks: | ||
try: | ||
sub.close() | ||
except Exception: # noqa: E722 | ||
pass |
Oops, something went wrong.