Skip to content

Commit

Permalink
Merge pull request #54 from stinovlas/use-blocking-connection-pool
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Feb 12, 2024
2 parents 09e835e + 7cfb6f7 commit 0489e6c
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 10 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ Brokers parameters:
* `result_backend` - custom result backend.
* `queue_name` - name of the pub/sub channel in redis.
* `max_connection_pool_size` - maximum number of connections in pool.
* Any other keyword arguments are passed to `redis.asyncio.BlockingConnectionPool`.
Notably, you can use `timeout` to set custom timeout in seconds for reconnects
(or set it to `None` to try reconnects indefinitely).

## RedisAsyncResultBackend configuration

Expand All @@ -79,6 +82,9 @@ RedisAsyncResultBackend parameters:
* `keep_results` - flag to not remove results from Redis after reading.
* `result_ex_time` - expire time in seconds (by default - not specified)
* `result_px_time` - expire time in milliseconds (by default - not specified)
* Any other keyword arguments are passed to `redis.asyncio.BlockingConnectionPool`.
Notably, you can use `timeout` to set custom timeout in seconds for reconnects
(or set it to `None` to try reconnects indefinitely).
> IMPORTANT: **It is highly recommended to use expire time ​​in RedisAsyncResultBackend**
> If you want to add expiration, either `result_ex_time` or `result_px_time` must be set.
>```python
Expand Down
21 changes: 17 additions & 4 deletions taskiq_redis/redis_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pickle
from typing import Dict, Optional, TypeVar, Union
from typing import Any, Dict, Optional, TypeVar, Union

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio import BlockingConnectionPool, Redis
from redis.asyncio.cluster import RedisCluster
from taskiq import AsyncResultBackend
from taskiq.abc.result_backend import TaskiqResult
Expand All @@ -24,6 +24,8 @@ def __init__(
keep_results: bool = True,
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
max_connection_pool_size: Optional[int] = None,
**connection_kwargs: Any,
) -> None:
"""
Constructs a new result backend.
Expand All @@ -32,13 +34,19 @@ def __init__(
:param keep_results: flag to not remove results from Redis after reading.
:param result_ex_time: expire time in seconds for result.
:param result_px_time: expire time in milliseconds for result.
:param max_connection_pool_size: maximum number of connections in pool.
:param connection_kwargs: additional arguments for redis BlockingConnectionPool.
:raises DuplicateExpireTimeSelectedError: if result_ex_time
and result_px_time are selected.
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
and result_px_time are equal zero.
"""
self.redis_pool = ConnectionPool.from_url(redis_url)
self.redis_pool = BlockingConnectionPool.from_url(
url=redis_url,
max_connections=max_connection_pool_size,
**connection_kwargs,
)
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -146,6 +154,7 @@ def __init__(
keep_results: bool = True,
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
**connection_kwargs: Any,
) -> None:
"""
Constructs a new result backend.
Expand All @@ -154,13 +163,17 @@ def __init__(
:param keep_results: flag to not remove results from Redis after reading.
:param result_ex_time: expire time in seconds for result.
:param result_px_time: expire time in milliseconds for result.
:param connection_kwargs: additional arguments for RedisCluster.
:raises DuplicateExpireTimeSelectedError: if result_ex_time
and result_px_time are selected.
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
and result_px_time are equal zero.
"""
self.redis: RedisCluster[bytes] = RedisCluster.from_url(redis_url)
self.redis: RedisCluster[bytes] = RedisCluster.from_url(
redis_url,
**connection_kwargs,
)
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down
8 changes: 5 additions & 3 deletions taskiq_redis/redis_broker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from logging import getLogger
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis
from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.message import BrokerMessage
Expand Down Expand Up @@ -31,14 +31,16 @@ def __init__(
:param result_backend: custom result backend.
:param queue_name: name for a list in redis.
:param max_connection_pool_size: maximum number of connections in pool.
:param connection_kwargs: additional arguments for aio-redis ConnectionPool.
Each worker opens its own connection. Therefore this value has to be
at least number of workers + 1.
:param connection_kwargs: additional arguments for redis BlockingConnectionPool.
"""
super().__init__(
result_backend=result_backend,
task_id_generator=task_id_generator,
)

self.connection_pool: ConnectionPool = ConnectionPool.from_url(
self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
Expand Down
6 changes: 3 additions & 3 deletions taskiq_redis/schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

from redis.asyncio import ConnectionPool, Redis, RedisCluster
from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis, RedisCluster
from taskiq import ScheduleSource
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
Expand All @@ -22,7 +22,7 @@ class RedisScheduleSource(ScheduleSource):
This is how many keys will be fetched at once.
:param max_connection_pool_size: maximum number of connections in pool.
:param serializer: serializer for data.
:param connection_kwargs: additional arguments for aio-redis ConnectionPool.
:param connection_kwargs: additional arguments for redis BlockingConnectionPool.
"""

def __init__(
Expand All @@ -35,7 +35,7 @@ def __init__(
**connection_kwargs: Any,
) -> None:
self.prefix = prefix
self.connection_pool: ConnectionPool = ConnectionPool.from_url(
self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
Expand Down
46 changes: 46 additions & 0 deletions tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,29 @@ async def test_pub_sub_broker(
await broker.shutdown()


@pytest.mark.anyio
async def test_pub_sub_broker_max_connections(
valid_broker_message: BrokerMessage,
redis_url: str,
) -> None:
"""Test PubSubBroker with connection limit set."""
broker = PubSubBroker(
url=redis_url,
queue_name=uuid.uuid4().hex,
max_connection_pool_size=4,
timeout=1,
)
worker_tasks = [asyncio.create_task(get_message(broker)) for _ in range(3)]
await asyncio.sleep(0.3)

await asyncio.gather(*[broker.kick(valid_broker_message) for _ in range(50)])
await asyncio.sleep(0.3)

for worker in worker_tasks:
worker.cancel()
await broker.shutdown()


@pytest.mark.anyio
async def test_list_queue_broker(
valid_broker_message: BrokerMessage,
Expand Down Expand Up @@ -98,6 +121,29 @@ async def test_list_queue_broker(
await broker.shutdown()


@pytest.mark.anyio
async def test_list_queue_broker_max_connections(
valid_broker_message: BrokerMessage,
redis_url: str,
) -> None:
"""Test ListQueueBroker with connection limit set."""
broker = ListQueueBroker(
url=redis_url,
queue_name=uuid.uuid4().hex,
max_connection_pool_size=4,
timeout=1,
)
worker_tasks = [asyncio.create_task(get_message(broker)) for _ in range(3)]
await asyncio.sleep(0.3)

await asyncio.gather(*[broker.kick(valid_broker_message) for _ in range(50)])
await asyncio.sleep(0.3)

for worker in worker_tasks:
worker.cancel()
await broker.shutdown()


@pytest.mark.anyio
async def test_list_queue_cluster_broker(
valid_broker_message: BrokerMessage,
Expand Down
33 changes: 33 additions & 0 deletions tests/test_result_backend.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import uuid

import pytest
Expand Down Expand Up @@ -132,6 +133,38 @@ async def test_keep_results_after_reading(redis_url: str) -> None:
await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_result_max_connections(redis_url: str) -> None:
"""
Tests that asynchronous backend works with connection limit.
:param redis_url: redis URL.
"""
result_backend = RedisAsyncResultBackend( # type: ignore
redis_url=redis_url,
max_connection_pool_size=1,
timeout=3,
)

task_id = uuid.uuid4().hex
result: "TaskiqResult[int]" = TaskiqResult(
is_err=True,
log="My Log",
return_value=11,
execution_time=112.2,
)
await result_backend.set_result(
task_id=task_id,
result=result,
)

async def get_result() -> None:
await result_backend.get_result(task_id=task_id, with_logs=True)

await asyncio.gather(*[get_result() for _ in range(10)])
await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_result_success_cluster(redis_cluster_url: str) -> None:
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/test_schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import datetime as dt
import uuid

Expand Down Expand Up @@ -108,6 +109,18 @@ async def test_buffer(redis_url: str) -> None:
await source.shutdown()


@pytest.mark.anyio
async def test_max_connections(redis_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisScheduleSource(
redis_url,
prefix=prefix,
max_connection_pool_size=1,
timeout=3,
)
await asyncio.gather(*[source.get_schedules() for _ in range(10)])


@pytest.mark.anyio
async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
prefix = uuid.uuid4().hex
Expand Down

0 comments on commit 0489e6c

Please sign in to comment.