From 8c56ce6f06fb97c486c47693879cf8ecfa4f1e19 Mon Sep 17 00:00:00 2001 From: Bruce Merry Date: Thu, 2 Sep 2021 19:04:07 +0200 Subject: [PATCH] Make StrictRedis.from_url work Closes #19. --- birdisle/redis.py | 14 ++++++++++++++ tests/test_birdisle.py | 12 +++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/birdisle/redis.py b/birdisle/redis.py index 303a6f1..2a597d2 100644 --- a/birdisle/redis.py +++ b/birdisle/redis.py @@ -92,6 +92,20 @@ def __init__(self, server=None, host='localhost', port=6379, ssl_check_hostname, max_connections, single_connection_client, health_check_interval, client_name, username) + @classmethod + def from_url(cls, url, db=None, **kwargs): + server = kwargs.pop('server', None) + if server is None: + server = birdisle.Server() + self = super().from_url(url, db, **kwargs) + self.connection_pool.connection_class = LocalSocketConnection + self.connection_pool.connection_kwargs['server'] = server + # When url is a unix:// URL, connection_kwargs will include 'path', + # but LocalSocketConnection does not expect that (because the base + # class does not). + self.connection_pool.connection_kwargs.pop('path', None) + return self + class StrictRedis(RedisMixin, redis.StrictRedis): """Replacement for :class:`redis.StrictRedis` that connects to a birdisle server. diff --git a/tests/test_birdisle.py b/tests/test_birdisle.py index e681782..76307ce 100644 --- a/tests/test_birdisle.py +++ b/tests/test_birdisle.py @@ -20,7 +20,7 @@ def server(): @pytest.fixture def r(server): redis = birdisle.redis.StrictRedis(server=server) - yield redis + return redis @pytest.fixture @@ -166,6 +166,16 @@ def test_shared_server(server): assert b.get('foo') == b'bar' +@pytest.mark.parametrize('url', ['unix:///some/path/?db=7', 'redis://host.invalid:12345/7']) +def test_from_url(server, url): + r = birdisle.redis.StrictRedis.from_url(url, server=server) + r.set('hello', 'world') + assert r.get('hello') == b'world' + # Check that we're using DB 7 (more specifically, not 0) by swapping it away + r.swapdb(7, 8) + assert r.get('hello') is None + + def test_signals(r, profile_timer): """Test that signal delivery doesn't interfere with birdisle""" def handler(signum, frame):