Skip to content

Commit

Permalink
feat: Add get_progress and set_progress to redis result backend
Browse files Browse the repository at this point in the history
Uses as standard suffix on the redis key (hardcoded as "__progress") to store progress results
  • Loading branch information
sminnee committed Jul 25, 2024
1 parent 5a66d5b commit c3d7b5a
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 0 deletions.
148 changes: 148 additions & 0 deletions taskiq_redis/redis_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from taskiq.abc.result_backend import TaskiqResult
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.serializers import PickleSerializer

from taskiq_redis.exceptions import (
Expand All @@ -41,6 +42,8 @@

_ReturnType = TypeVar("_ReturnType")

PROGRESS_KEY_SUFFIX = "__progress"


class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
"""Async result based on redis."""
Expand Down Expand Up @@ -174,6 +177,55 @@ async def get_result(

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Sets task progress in redis.
Dumps TaskProgress instance into the bytes and writes
it to redis with a standard suffix on the task_id as the key
:param task_id: ID of the task.
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
elif self.result_px_time:
redis_set_params["px"] = self.result_px_time

async with Redis(connection_pool=self.redis_pool) as redis:
await redis.set(**redis_set_params) # type: ignore

async def get_progress(
self,
task_id: str,
) -> Union[TaskProgress[_ReturnType], None]:
"""
Gets progress results from the task.
:param task_id: task's id.
:return: task's TaskProgress instance.
"""
async with Redis(connection_pool=self.redis_pool) as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(result_value),
)


class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
"""Async result backend based on redis cluster."""
Expand Down Expand Up @@ -301,6 +353,53 @@ async def get_result(

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Sets task progress in redis.
Dumps TaskProgress instance into the bytes and writes
it to redis with a standard suffix on the task_id as the key
:param task_id: ID of the task.
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
elif self.result_px_time:
redis_set_params["px"] = self.result_px_time

await self.redis.set(**redis_set_params) # type: ignore

async def get_progress(
self,
task_id: str,
) -> Union[TaskProgress[_ReturnType], None]:
"""
Gets progress results from the task.
:param task_id: task's id.
:return: task's TaskProgress instance.
"""
result_value = await self.redis.get( # type: ignore[attr-defined]
name=task_id + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(result_value),
)


class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
"""Async result based on redis sentinel."""
Expand Down Expand Up @@ -439,6 +538,55 @@ async def get_result(

return taskiq_result

async def set_progress(
self,
task_id: str,
progress: TaskProgress[_ReturnType],
) -> None:
"""
Sets task progress in redis.
Dumps TaskProgress instance into the bytes and writes
it to redis with a standard suffix on the task_id as the key
:param task_id: ID of the task.
:param result: task's TaskProgress instance.
"""
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id + PROGRESS_KEY_SUFFIX,
"value": self.serializer.dumpb(model_dump(progress)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
elif self.result_px_time:
redis_set_params["px"] = self.result_px_time

async with self._acquire_master_conn() as redis:
await redis.set(**redis_set_params) # type: ignore

async def get_progress(
self,
task_id: str,
) -> Union[TaskProgress[_ReturnType], None]:
"""
Gets progress results from the task.
:param task_id: task's id.
:return: task's TaskProgress instance.
"""
async with self._acquire_master_conn() as redis:
result_value = await redis.get(
name=task_id + PROGRESS_KEY_SUFFIX,
)

if result_value is None:
return None

return model_validate(
TaskProgress[_ReturnType],
self.serializer.loadb(result_value),
)

async def shutdown(self) -> None:
"""Shutdown sentinel connections."""
for sentinel in self.sentinel.sentinels:
Expand Down
122 changes: 122 additions & 0 deletions tests/test_result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
from taskiq import TaskiqResult
from taskiq.depends.progress_tracker import TaskProgress, TaskState

from taskiq_redis import (
RedisAsyncClusterResultBackend,
Expand Down Expand Up @@ -438,3 +439,124 @@ async def test_keep_results_after_reading_sentinel(
res2 = await result_backend.get_result(task_id=task_id)
assert res1 == res2
await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_progress(redis_url: str) -> None:
"""
Test that set_progress/get_progress works.
:param redis_url: redis URL.
"""
result_backend = RedisAsyncResultBackend( # type: ignore
redis_url=redis_url,
)
task_id = uuid.uuid4().hex

test_progress_1 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "quarter way", "pct": 25},
)
test_progress_2 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "half way", "pct": 50},
)

# Progress starts as None
assert await result_backend.get_progress(task_id=task_id) is None

# Setting the first time persists
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_1

# Setting the second time replaces the first
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_2

await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_progress_cluster(redis_cluster_url: str) -> None:
"""
Test that set_progress/get_progress works in cluster mode.
:param redis_url: redis URL.
"""
result_backend = RedisAsyncClusterResultBackend( # type: ignore
redis_url=redis_cluster_url,
)
task_id = uuid.uuid4().hex

test_progress_1 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "quarter way", "pct": 25},
)
test_progress_2 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "half way", "pct": 50},
)

# Progress starts as None
assert await result_backend.get_progress(task_id=task_id) is None

# Setting the first time persists
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_1

# Setting the second time replaces the first
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_2

await result_backend.shutdown()


@pytest.mark.anyio
async def test_set_progress_sentinel(
redis_sentinels: List[Tuple[str, int]],
redis_sentinel_master_name: str,
) -> None:
"""
Test that set_progress/get_progress works in cluster mode.
:param redis_url: redis URL.
"""
result_backend = RedisAsyncSentinelResultBackend( # type: ignore
sentinels=redis_sentinels,
master_name=redis_sentinel_master_name,
)
task_id = uuid.uuid4().hex

test_progress_1 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "quarter way", "pct": 25},
)
test_progress_2 = TaskProgress(
state=TaskState.STARTED,
meta={"message": "half way", "pct": 50},
)

# Progress starts as None
assert await result_backend.get_progress(task_id=task_id) is None

# Setting the first time persists
await result_backend.set_progress(task_id=task_id, progress=test_progress_1)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_1

# Setting the second time replaces the first
await result_backend.set_progress(task_id=task_id, progress=test_progress_2)

fetched_result = await result_backend.get_progress(task_id=task_id)
assert fetched_result == test_progress_2

await result_backend.shutdown()

0 comments on commit c3d7b5a

Please sign in to comment.