Skip to content

Commit

Permalink
Merge branch 'release/0.5.3'
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius committed Nov 13, 2023
2 parents 039bd53 + 6a29f1b commit c10881f
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "taskiq-redis"
version = "0.5.2"
version = "0.5.3"
description = "Redis integration for taskiq"
authors = ["taskiq-team <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 2 additions & 0 deletions taskiq_redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
RedisAsyncResultBackend,
)
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
from taskiq_redis.schedule_source import RedisScheduleSource

__all__ = [
"RedisAsyncClusterResultBackend",
"RedisAsyncResultBackend",
"ListQueueBroker",
"PubSubBroker",
"ListQueueClusterBroker",
"RedisScheduleSource",
]
67 changes: 67 additions & 0 deletions taskiq_redis/redis_cluster_broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any, AsyncGenerator

from redis.asyncio import RedisCluster
from taskiq.abc.broker import AsyncBroker
from taskiq.message import BrokerMessage


class BaseRedisClusterBroker(AsyncBroker):
"""Base broker that works with Redis Cluster."""

def __init__(
self,
url: str,
queue_name: str = "taskiq",
max_connection_pool_size: int = 2**31,
**connection_kwargs: Any,
) -> None:
"""
Constructs a new broker.
:param url: url to redis.
:param queue_name: name for a list in redis.
:param max_connection_pool_size: maximum number of connections in pool.
:param connection_kwargs: additional arguments for aio-redis ConnectionPool.
"""
super().__init__()

self.redis: RedisCluster[bytes] = RedisCluster.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
)

self.queue_name = queue_name

async def shutdown(self) -> None:
"""Closes redis connection pool."""
await self.redis.aclose() # type: ignore[attr-defined]
await super().shutdown()


class ListQueueClusterBroker(BaseRedisClusterBroker):
"""Broker that works with Redis Cluster and distributes tasks between workers."""

async def kick(self, message: BrokerMessage) -> None:
"""
Put a message in a list.
This method appends a message to the list of all messages.
:param message: message to append.
"""
await self.redis.lpush(self.queue_name, message.message) # type: ignore[attr-defined]

async def listen(self) -> AsyncGenerator[bytes, None]:
"""
Listen redis queue for new messages.
This function listens to the queue
and yields new messages if they have BrokerMessage type.
:yields: broker messages.
"""
redis_brpop_data_position = 1
while True:
value = await self.redis.brpop([self.queue_name]) # type: ignore[attr-defined]
yield value[redis_brpop_data_position]
29 changes: 28 additions & 1 deletion tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from taskiq import AckableMessage, AsyncBroker, BrokerMessage

from taskiq_redis import ListQueueBroker, PubSubBroker
from taskiq_redis import ListQueueBroker, ListQueueClusterBroker, PubSubBroker


def test_no_url_should_raise_typeerror() -> None:
Expand Down Expand Up @@ -96,3 +96,30 @@ async def test_list_queue_broker(
worker1_task.cancel()
worker2_task.cancel()
await broker.shutdown()


@pytest.mark.anyio
async def test_list_queue_cluster_broker(
valid_broker_message: BrokerMessage,
redis_cluster_url: str,
) -> None:
"""
Test that messages are published and read correctly by ListQueueClusterBroker.
We create two workers that listen and send a message to them.
Expect only one worker to receive the same message we sent.
"""
broker = ListQueueClusterBroker(
url=redis_cluster_url,
queue_name=uuid.uuid4().hex,
)
worker_task = asyncio.create_task(get_message(broker))
await asyncio.sleep(0.3)

await broker.kick(valid_broker_message)
await asyncio.sleep(0.3)

assert worker_task.done()
assert worker_task.result() == valid_broker_message.message
worker_task.cancel()
await broker.shutdown()

0 comments on commit c10881f

Please sign in to comment.