Skip to content
This repository has been archived by the owner on Apr 5, 2023. It is now read-only.

Commit

Permalink
Make StrictRedis.from_url work
Browse files Browse the repository at this point in the history
Closes #19.
  • Loading branch information
bmerry committed Sep 2, 2021
1 parent 7175172 commit 8c56ce6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
14 changes: 14 additions & 0 deletions birdisle/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ def __init__(self, server=None, host='localhost', port=6379,
ssl_check_hostname, max_connections, single_connection_client,
health_check_interval, client_name, username)

@classmethod
def from_url(cls, url, db=None, **kwargs):
server = kwargs.pop('server', None)
if server is None:
server = birdisle.Server()
self = super().from_url(url, db, **kwargs)
self.connection_pool.connection_class = LocalSocketConnection
self.connection_pool.connection_kwargs['server'] = server
# When url is a unix:// URL, connection_kwargs will include 'path',
# but LocalSocketConnection does not expect that (because the base
# class does not).
self.connection_pool.connection_kwargs.pop('path', None)
return self


class StrictRedis(RedisMixin, redis.StrictRedis):
"""Replacement for :class:`redis.StrictRedis` that connects to a birdisle server.
Expand Down
12 changes: 11 additions & 1 deletion tests/test_birdisle.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def server():
@pytest.fixture
def r(server):
redis = birdisle.redis.StrictRedis(server=server)
yield redis
return redis


@pytest.fixture
Expand Down Expand Up @@ -166,6 +166,16 @@ def test_shared_server(server):
assert b.get('foo') == b'bar'


@pytest.mark.parametrize('url', ['unix:///some/path/?db=7', 'redis://host.invalid:12345/7'])
def test_from_url(server, url):
r = birdisle.redis.StrictRedis.from_url(url, server=server)
r.set('hello', 'world')
assert r.get('hello') == b'world'
# Check that we're using DB 7 (more specifically, not 0) by swapping it away
r.swapdb(7, 8)
assert r.get('hello') is None


def test_signals(r, profile_timer):
"""Test that signal delivery doesn't interfere with birdisle"""
def handler(signum, frame):
Expand Down

0 comments on commit 8c56ce6

Please sign in to comment.