Skip to content

Commit

Permalink
Add QueueServicePool
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Jan 4, 2024
1 parent bdb401d commit f1bf435
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 11 deletions.
32 changes: 30 additions & 2 deletions safe_transaction_service/events/services/queue_service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
from functools import lru_cache
from typing import Any, Dict, List, Optional

from django.conf import settings
Expand All @@ -14,7 +13,6 @@
logger = logging.getLogger(__name__)


@lru_cache()
def getQueueService():
if settings.EVENTS_QUEUE_URL:
return SyncQueueService()
Expand All @@ -24,6 +22,36 @@ def getQueueService():
logger.warning("MockedQueueService is used")


class QueueServicePool:
"""
Context manager to get a QueueService connection from the pool or create a new one and append it to the pool if all the
instances are taken. Very useful for gevent, as it is not safe to share one Pika connection across threads.
https://pika.readthedocs.io/en/stable/faq.html
Use:
```
with QueueServicePool() as queue_service:
queue_service...
```
"""

queue_service_pool = []

def __init__(self):
self.instance: QueueService

def __enter__(self):
if self.queue_service_pool:
# If there are elements on the pool, take them
self.instance = self.queue_service_pool.pop()
else:
# If not, get a new client
self.instance = getQueueService()
return self.instance

def __exit__(self, exc_type, exc_val, exc_tb):
self.queue_service_pool.append(self.instance)


class QueueService:
def __init__(self):
self.exchange_name: str = settings.EVENTS_QUEUE_EXCHANGE_NAME
Expand Down
24 changes: 21 additions & 3 deletions safe_transaction_service/events/tests/test_queue_service.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import json
from unittest import mock
from unittest.mock import MagicMock

from django.test import TestCase

from pika.channel import Channel
from pika.exceptions import ConnectionClosedByBroker

from safe_transaction_service.events.services.queue_service import getQueueService
from safe_transaction_service.events.services.queue_service import (
QueueServicePool,
getQueueService,
)


class TestQueueService(TestCase):
def setUp(self):
self.queue_service = getQueueService()
# Ensure that is singleton
self.assertEqual(self.queue_service, getQueueService())
# Create queue for test
self.queue = "test_queue"
self.queue_service._channel.queue_declare(self.queue)
Expand Down Expand Up @@ -57,3 +59,19 @@ def test_send_event_to_queue(self):
# Check if message was written to the queue
_, _, body = self.queue_service._channel.basic_get(self.queue, auto_ack=True)
self.assertEquals(json.loads(body), payload)

@mock.patch(
"safe_transaction_service.events.services.queue_service.getQueueService"
)
def test_queue_service_pool(self, mock_get_queue_service: MagicMock):
queue_service = getQueueService()
QueueServicePool.queue_service_pool = [queue_service]
with QueueServicePool() as queue_service:
self.assertEqual(queue_service, queue_service)
mock_get_queue_service.assert_not_called()

QueueServicePool.queue_service_pool = []
mock_get_queue_service.return_value = queue_service
with QueueServicePool() as queue_service:
self.assertEqual(queue_service, queue_service)
mock_get_queue_service.assert_called_once()
6 changes: 3 additions & 3 deletions safe_transaction_service/history/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from safe_transaction_service.notifications.tasks import send_notification_task

from ..events.services.queue_service import getQueueService
from ..events.services.queue_service import QueueServicePool
from .models import (
ERC20Transfer,
ERC721Transfer,
Expand Down Expand Up @@ -168,8 +168,8 @@ def process_webhook(
countdown=5,
priority=2, # Almost lowest priority
)
queue_service = getQueueService()
queue_service.send_event(payload)
with QueueServicePool() as queue_service:
queue_service.send_event(payload)
else:
logger.debug(
"Notification will not be sent for created=%s object=%s",
Expand Down
6 changes: 3 additions & 3 deletions safe_transaction_service/safe_messages/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.db.models.signals import post_save
from django.dispatch import receiver

from safe_transaction_service.events.services.queue_service import getQueueService
from safe_transaction_service.events.services.queue_service import QueueServicePool
from safe_transaction_service.history.services.webhooks import build_webhook_payload
from safe_transaction_service.history.tasks import send_webhook_task
from safe_transaction_service.safe_messages.models import (
Expand Down Expand Up @@ -46,8 +46,8 @@ def process_webhook(
send_webhook_task.apply_async(
args=(address, payload), priority=2 # Almost lowest priority
) # Almost the lowest priority
queue_service = getQueueService()
queue_service.send_event(payload)
with QueueServicePool() as queue_service:
queue_service.send_event(payload)
else:
logger.debug(
"Notification will not be sent for created=%s object=%s",
Expand Down

0 comments on commit f1bf435

Please sign in to comment.