From b19571ddd3ffa9f3c5df2dea94a35156796c8fbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sam=20Minn=C3=A9e?= Date: Wed, 24 Jul 2024 16:34:01 +1200 Subject: [PATCH 1/2] fix: use poetry installs of mypy & black in the pre-commit hook --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21b7e45..aadae04 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: hooks: - id: black name: Format with Black - entry: black + entry: poetry run black language: system types: [python] @@ -36,6 +36,6 @@ repos: - id: mypy name: Validate types with MyPy - entry: mypy + entry: poetry run mypy language: system types: [ python ] From 8a97ced7af38094f46bdc2a2b1c5e3703ede3083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sam=20Minn=C3=A9e?= Date: Wed, 24 Jul 2024 16:34:59 +1200 Subject: [PATCH 2/2] feat: Add get_progress and set_progress to redis result backend Uses as standard suffix on the redis key (hardcoded as "__progress") to store progress results --- taskiq_redis/redis_backend.py | 148 ++++++++++++++++++++++++++++++++++ tests/test_result_backend.py | 122 ++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+) diff --git a/taskiq_redis/redis_backend.py b/taskiq_redis/redis_backend.py index 28ce927..104d8cc 100644 --- a/taskiq_redis/redis_backend.py +++ b/taskiq_redis/redis_backend.py @@ -19,6 +19,7 @@ from taskiq.abc.result_backend import TaskiqResult from taskiq.abc.serializer import TaskiqSerializer from taskiq.compat import model_dump, model_validate +from taskiq.depends.progress_tracker import TaskProgress from taskiq.serializers import PickleSerializer from taskiq_redis.exceptions import ( @@ -41,6 +42,8 @@ _ReturnType = TypeVar("_ReturnType") +PROGRESS_KEY_SUFFIX = "__progress" + class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]): """Async result based on redis.""" @@ -174,6 +177,55 @@ async def get_result( return taskiq_result 
+ async def set_progress( + self, + task_id: str, + progress: TaskProgress[_ReturnType], + ) -> None: + """ + Sets task progress in redis. + + Dumps TaskProgress instance into the bytes and writes + it to redis with a standard suffix on the task_id as the key + + :param task_id: ID of the task. + :param progress: task's TaskProgress instance. + """ + redis_set_params: Dict[str, Union[str, int, bytes]] = { + "name": task_id + PROGRESS_KEY_SUFFIX, + "value": self.serializer.dumpb(model_dump(progress)), + } + if self.result_ex_time: + redis_set_params["ex"] = self.result_ex_time + elif self.result_px_time: + redis_set_params["px"] = self.result_px_time + + async with Redis(connection_pool=self.redis_pool) as redis: + await redis.set(**redis_set_params) # type: ignore + + async def get_progress( + self, + task_id: str, + ) -> Union[TaskProgress[_ReturnType], None]: + """ + Gets progress results from the task. + + :param task_id: task's id. + :return: task's TaskProgress instance. + """ + async with Redis(connection_pool=self.redis_pool) as redis: + result_value = await redis.get( + name=task_id + PROGRESS_KEY_SUFFIX, + ) + + if result_value is None: + return None + + return model_validate( + TaskProgress[_ReturnType], + self.serializer.loadb(result_value), + ) + class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]): """Async result backend based on redis cluster.""" @@ -301,6 +353,53 @@ async def get_result( return taskiq_result + async def set_progress( + self, + task_id: str, + progress: TaskProgress[_ReturnType], + ) -> None: + """ + Sets task progress in redis. + + Dumps TaskProgress instance into the bytes and writes + it to redis with a standard suffix on the task_id as the key + + :param task_id: ID of the task. + :param progress: task's TaskProgress instance. 
+ """ + redis_set_params: Dict[str, Union[str, int, bytes]] = { + "name": task_id + PROGRESS_KEY_SUFFIX, + "value": self.serializer.dumpb(model_dump(progress)), + } + if self.result_ex_time: + redis_set_params["ex"] = self.result_ex_time + elif self.result_px_time: + redis_set_params["px"] = self.result_px_time + + await self.redis.set(**redis_set_params) # type: ignore + + async def get_progress( + self, + task_id: str, + ) -> Union[TaskProgress[_ReturnType], None]: + """ + Gets progress results from the task. + + :param task_id: task's id. + :return: task's TaskProgress instance. + """ + result_value = await self.redis.get( # type: ignore[attr-defined] + name=task_id + PROGRESS_KEY_SUFFIX, + ) + + if result_value is None: + return None + + return model_validate( + TaskProgress[_ReturnType], + self.serializer.loadb(result_value), + ) + class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]): """Async result based on redis sentinel.""" @@ -439,6 +538,55 @@ async def get_result( return taskiq_result + async def set_progress( + self, + task_id: str, + progress: TaskProgress[_ReturnType], + ) -> None: + """ + Sets task progress in redis. + + Dumps TaskProgress instance into the bytes and writes + it to redis with a standard suffix on the task_id as the key + + :param task_id: ID of the task. + :param progress: task's TaskProgress instance. + """ + redis_set_params: Dict[str, Union[str, int, bytes]] = { + "name": task_id + PROGRESS_KEY_SUFFIX, + "value": self.serializer.dumpb(model_dump(progress)), + } + if self.result_ex_time: + redis_set_params["ex"] = self.result_ex_time + elif self.result_px_time: + redis_set_params["px"] = self.result_px_time + + async with self._acquire_master_conn() as redis: + await redis.set(**redis_set_params) # type: ignore + + async def get_progress( + self, + task_id: str, + ) -> Union[TaskProgress[_ReturnType], None]: + """ + Gets progress results from the task. + + :param task_id: task's id. 
+ :return: task's TaskProgress instance. + """ + async with self._acquire_master_conn() as redis: + result_value = await redis.get( + name=task_id + PROGRESS_KEY_SUFFIX, + ) + + if result_value is None: + return None + + return model_validate( + TaskProgress[_ReturnType], + self.serializer.loadb(result_value), + ) + async def shutdown(self) -> None: """Shutdown sentinel connections.""" for sentinel in self.sentinel.sentinels: diff --git a/tests/test_result_backend.py b/tests/test_result_backend.py index 68ed965..cd0a12d 100644 --- a/tests/test_result_backend.py +++ b/tests/test_result_backend.py @@ -4,6 +4,7 @@ import pytest from taskiq import TaskiqResult +from taskiq.depends.progress_tracker import TaskProgress, TaskState from taskiq_redis import ( RedisAsyncClusterResultBackend, @@ -438,3 +439,124 @@ async def test_keep_results_after_reading_sentinel( res2 = await result_backend.get_result(task_id=task_id) assert res1 == res2 await result_backend.shutdown() + + +@pytest.mark.anyio +async def test_set_progress(redis_url: str) -> None: + """ + Test that set_progress/get_progress works. + + :param redis_url: redis URL. 
+ """ + result_backend = RedisAsyncResultBackend( # type: ignore + redis_url=redis_url, + ) + task_id = uuid.uuid4().hex + + test_progress_1 = TaskProgress( + state=TaskState.STARTED, + meta={"message": "quarter way", "pct": 25}, + ) + test_progress_2 = TaskProgress( + state=TaskState.STARTED, + meta={"message": "half way", "pct": 50}, + ) + + # Progress starts as None + assert await result_backend.get_progress(task_id=task_id) is None + + # Setting the first time persists + await result_backend.set_progress(task_id=task_id, progress=test_progress_1) + + fetched_result = await result_backend.get_progress(task_id=task_id) + assert fetched_result == test_progress_1 + + # Setting the second time replaces the first + await result_backend.set_progress(task_id=task_id, progress=test_progress_2) + + fetched_result = await result_backend.get_progress(task_id=task_id) + assert fetched_result == test_progress_2 + + await result_backend.shutdown() + + +@pytest.mark.anyio +async def test_set_progress_cluster(redis_cluster_url: str) -> None: + """ + Test that set_progress/get_progress works in cluster mode. + + :param redis_cluster_url: redis cluster URL. 
+ """ + result_backend = RedisAsyncClusterResultBackend( # type: ignore + redis_url=redis_cluster_url, + ) + task_id = uuid.uuid4().hex + + test_progress_1 = TaskProgress( + state=TaskState.STARTED, + meta={"message": "quarter way", "pct": 25}, + ) + test_progress_2 = TaskProgress( + state=TaskState.STARTED, + meta={"message": "half way", "pct": 50}, + ) + + # Progress starts as None + assert await result_backend.get_progress(task_id=task_id) is None + + # Setting the first time persists + await result_backend.set_progress(task_id=task_id, progress=test_progress_1) + + fetched_result = await result_backend.get_progress(task_id=task_id) + assert fetched_result == test_progress_1 + + # Setting the second time replaces the first + await result_backend.set_progress(task_id=task_id, progress=test_progress_2) + + fetched_result = await result_backend.get_progress(task_id=task_id) + assert fetched_result == test_progress_2 + + await result_backend.shutdown() + + +@pytest.mark.anyio +async def test_set_progress_sentinel( + redis_sentinels: List[Tuple[str, int]], + redis_sentinel_master_name: str, +) -> None: + """ + Test that set_progress/get_progress works in sentinel mode. + + :param redis_sentinels: redis sentinel addresses. 
+ """ + result_backend = RedisAsyncSentinelResultBackend( # type: ignore + sentinels=redis_sentinels, + master_name=redis_sentinel_master_name, + ) + task_id = uuid.uuid4().hex + + test_progress_1 = TaskProgress( + state=TaskState.STARTED, + meta={"message": "quarter way", "pct": 25}, + ) + test_progress_2 = TaskProgress( + state=TaskState.STARTED, + meta={"message": "half way", "pct": 50}, + ) + + # Progress starts as None + assert await result_backend.get_progress(task_id=task_id) is None + + # Setting the first time persists + await result_backend.set_progress(task_id=task_id, progress=test_progress_1) + + fetched_result = await result_backend.get_progress(task_id=task_id) + assert fetched_result == test_progress_1 + + # Setting the second time replaces the first + await result_backend.set_progress(task_id=task_id, progress=test_progress_2) + + fetched_result = await result_backend.get_progress(task_id=task_id) + assert fetched_result == test_progress_2 + + await result_backend.shutdown()