Skip to content

Commit

Permalink
pool coroutine is a proper backend for aioredis
Browse files Browse the repository at this point in the history
  • Loading branch information
youknowone committed Jun 24, 2018
1 parent f7172a7 commit 82003ba
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 49 deletions.
60 changes: 52 additions & 8 deletions ring/func/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,28 @@
inspect, 'iscoroutinefunction', lambda f: False)


class SingletonCoroutineProxy(object):

def __init__(self, awaitable):
if not asyncio.iscoroutine(awaitable):
raise TypeError(
"StorageProxy requires an awaitable object but '{}' found"
.format(type(awaitable)))
self.awaitable = awaitable
self.singleton = None

def __iter__(self):
if self.singleton is None:
if hasattr(self.awaitable, '__await__'):
awaitable = self.awaitable.__await__()
else:
awaitable = self.awaitable
self.singleton = yield from awaitable
return self.singleton

__await__ = __iter__


class NonAsyncioFactoryProxyBase(fbase.FactoryProxyBase):

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -333,41 +355,60 @@ class AioredisStorage(
CommonMixinStorage, fbase.StorageMixin, BulkStorageMixin):
"""Storage implementation for :class:`aioredis.Redis`."""

@asyncio.coroutine
def _get_backend(self):
if isinstance(self.backend, SingletonCoroutineProxy):
self.backend = yield from self.backend
return self.backend

@asyncio.coroutine
def get_value(self, key):
value = yield from self.backend.get(key)
backend = yield from self._get_backend()
value = yield from backend.get(key)
if value is None:
raise fbase.NotFound
return value

@asyncio.coroutine
def set_value(self, key, value, expire):
return self.backend.set(key, value, expire=expire)
backend = yield from self._get_backend()
result = yield from backend.set(key, value, expire=expire)
return result

@asyncio.coroutine
def delete_value(self, key):
return self.backend.delete(key)
backend = yield from self._get_backend()
result = yield from backend.delete(key)
return result

@asyncio.coroutine
def has_value(self, key):
result = yield from self.backend.exists(key)
backend = yield from self._get_backend()
result = yield from backend.exists(key)
return bool(result)

@asyncio.coroutine
def touch_value(self, key, expire):
if expire is None:
raise TypeError("'touch' is requested for persistent cache")
return self.backend.expire(key, expire)
backend = yield from self._get_backend()
result = yield from backend.expire(key, expire)
return result

@asyncio.coroutine
def get_many_values(self, keys):
values = yield from self.backend.mget(*keys)
backend = yield from self._get_backend()
values = yield from backend.mget(*keys)
return [v if v is not None else fbase.NotFound for v in values]

@asyncio.coroutine
def set_many_values(self, keys, values, expire):
params = itertools.chain.from_iterable(zip(keys, values))
yield from self.backend.mset(*params)
backend = yield from self._get_backend()
yield from backend.mset(*params)
if expire is not None:
asyncio.ensure_future(asyncio.gather(*(
self.backend.expire(key, expire) for key in keys)))
backend.expire(key, expire) for key in keys)))


def dict(
Expand Down Expand Up @@ -461,6 +502,9 @@ def aioredis(
:see: :func:`ring.redis` for non-asyncio version.
"""
if inspect.iscoroutine(redis):
redis = SingletonCoroutineProxy(redis)

return fbase.factory(
redis, key_prefix=key_prefix, on_manufactured=factory_doctor,
user_interface=user_interface, storage_class=storage_class,
Expand Down
4 changes: 2 additions & 2 deletions tests/_test_func_async_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

@pytest.mark.asyncio
async def test_async_def_vanilla_function(aiomcache_client):
storage = await aiomcache_client
storage, storage_ring = aiomcache_client

with pytest.raises(TypeError):
@storage.ring(storage)
@storage_ring(storage)
def vanilla_function():
pass

Expand Down
86 changes: 47 additions & 39 deletions tests/_test_func_asyncio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import time
import sys
import shelve

import aiomcache
Expand All @@ -11,35 +12,24 @@


@pytest.fixture()
@asyncio.coroutine
def storage_dict():
storage = StorageDict()
storage.ring = ring.dict
return storage
return storage, ring.dict


@pytest.fixture()
@asyncio.coroutine
def aiomcache_client():
client = aiomcache.Client('127.0.0.1', 11211)
client.ring = ring.func.asyncio.aiomcache
return client
return client, ring.func.asyncio.aiomcache


@pytest.fixture()
@asyncio.coroutine
def aioredis_pool():
import sys

if sys.version_info >= (3, 5):
import aioredis

global _aioredis_pool
_aioredis_pool = yield from aioredis.create_redis_pool(
pool_coroutine = aioredis.create_redis_pool(
('localhost', 6379), minsize=2, maxsize=2)
_aioredis_pool.ring = ring.redis
return _aioredis_pool

return pool_coroutine, ring.aioredis
else:
pytest.skip()

Expand All @@ -49,57 +39,76 @@ def aioredis_pool():
pytest.lazy_fixture('aiomcache_client'),
pytest.lazy_fixture('aioredis_pool'),
])
def gen_storage(request):
def storage_and_ring(request):
return request.param


@pytest.fixture()
def storage_shelve():
storage = shelve.open('/tmp/ring-test/shelvea')
storage.ring = ring.shelve
return storage
return storage, ring.shelve


@pytest.fixture()
def storage_disk(request):
client = diskcache.Cache('/tmp/ring-test/diskcache')
client.ring = ring.disk
return client
return client, ring.disk


@pytest.fixture(params=[
pytest.lazy_fixture('storage_shelve'),
pytest.lazy_fixture('storage_disk'),
])
def synchronous_storage(request):
def synchronous_storage_and_ring(request):
return request.param


@pytest.mark.asyncio
@asyncio.coroutine
def test_singleton_proxy():

@asyncio.coroutine
def client():
return object()

assert ((yield from client())) is not ((yield from client()))

proxy = ring.func.asyncio.SingletonCoroutineProxy(client())
assert ((yield from proxy)) is ((yield from proxy))


@pytest.mark.asyncio
@asyncio.coroutine
def test_vanilla_function(aiomcache_client):
storage = yield from aiomcache_client
storage, storage_ring = aiomcache_client

with pytest.raises(TypeError):
@storage.ring(storage)
@storage_ring(storage)
def vanilla_function():
pass


@pytest.mark.asyncio
@asyncio.coroutine
def test_common(gen_storage):
storage = yield from gen_storage
def test_common(storage_and_ring):
storage, storage_ring = storage_and_ring
base = [0]

@storage.ring(storage, 'ring-test !@#', 5)
@storage_ring(storage, 'ring-test !@#', 5)
@asyncio.coroutine
def f(a, b):
return str(base[0] + a * 100 + b).encode()

# `f` is a callable with argument `a` and `b`
# test f is correct
assert f.storage.backend is storage
if asyncio.iscoroutine(storage):
s1 = yield from f.storage.backend
with pytest.raises(RuntimeError):
yield from storage
s2 = yield from f.storage.backend
assert s1 is s2
else:
assert f.storage.backend is storage
assert f.key(a=0, b=0) # f takes a, b
assert base[0] is not None # f has attr base for test
assert ((yield from f.execute(a=1, b=2))) != ((yield from f.execute(a=1, b=3))) # f is not singular
Expand Down Expand Up @@ -153,11 +162,10 @@ def f(a, b):

@pytest.mark.asyncio
@asyncio.coroutine
def test_complicated_key(gen_storage):

storage = yield from gen_storage
def test_complicated_key(storage_and_ring):
storage, storage_ring = storage_and_ring

@storage.ring(storage)
@storage_ring(storage)
@asyncio.coroutine
def complicated(a, *args, b, **kw):
return b'42'
Expand Down Expand Up @@ -207,7 +215,7 @@ def f2(a, b):
@pytest.mark.asyncio
@asyncio.coroutine
def test_many(aiomcache_client):
client = yield from aiomcache_client
client, _ = aiomcache_client

@ring.aiomcache(client)
@asyncio.coroutine
Expand All @@ -229,7 +237,7 @@ def f(a):
@pytest.mark.asyncio
@asyncio.coroutine
def test_aiomcache(aiomcache_client):
client = yield from aiomcache_client
client, _ = aiomcache_client

@ring.aiomcache(client)
@asyncio.coroutine
Expand Down Expand Up @@ -269,7 +277,7 @@ def f(a):
@pytest.mark.asyncio
@asyncio.coroutine
def test_aioredis(aioredis_pool, expire):
client = yield from aioredis_pool
client, _ = aioredis_pool

@ring.aioredis(client, expire=expire)
@asyncio.coroutine
Expand Down Expand Up @@ -343,7 +351,7 @@ def f(a):
@pytest.mark.asyncio
@asyncio.coroutine
def test_func_method(storage_dict):
storage = yield from storage_dict
storage, _ = storage_dict

class A(object):
def __ring_key__(self):
Expand Down Expand Up @@ -377,16 +385,16 @@ def cmethod(cls, a, b):

@pytest.mark.asyncio
@asyncio.coroutine
def test_forced_sync(synchronous_storage):
storage = synchronous_storage
def test_forced_sync(synchronous_storage_and_ring):
storage, storage_ring = synchronous_storage_and_ring

with pytest.raises(TypeError):
@storage.ring(storage)
@storage_ring(storage)
@asyncio.coroutine
def g():
return 1

@storage.ring(storage, force_asyncio=True)
@storage_ring(storage, force_asyncio=True)
@asyncio.coroutine
def f(a):
return a
Expand Down

0 comments on commit 82003ba

Please sign in to comment.