diff --git a/corehq/motech/repeaters/tasks.py b/corehq/motech/repeaters/tasks.py
index 09793e187c45..b136d082a243 100644
--- a/corehq/motech/repeaters/tasks.py
+++ b/corehq/motech/repeaters/tasks.py
@@ -311,7 +311,9 @@ def process_repeaters():
                 continue
             if rate_limit_repeater(domain, repeater_id):
                 continue
-            lock_and_process_repeater_id(repeater_id)
+            lock = RepeaterLock(repeater_id)
+            if lock.acquire():
+                process_repeater(lock.repeater, lock.token)
 
 
 def iter_ready_repeater_ids():
@@ -341,36 +343,6 @@ def iter_ready_repeater_ids():
                 yield domain, repeater_id
 
 
-def lock_and_process_repeater_id(repeater_id):
-    lock = get_repeater_lock(repeater_id)
-    # Generate a lock token using `uuid1()` the same way that
-    # `redis.lock.Lock` does. The `Lock` class uses the token to
-    # determine ownership, so that one process can acquire a
-    # lock and a different process can release it. This lock
-    # will be released by the `update_repeater()` task.
-    lock_token = uuid.uuid1().hex
-    if lock.acquire(blocking=False, token=lock_token):
-        repeater = Repeater.objects.get(id=repeater_id)
-        process_repeater(repeater, lock_token)
-
-
-def relock_and_process_repeater(repeater, lock_token):
-    lock = get_repeater_lock(repeater.repeater_id, lock_token)
-    # Reset the lock timeout
-    # https://github.com/redis/redis-py/blob/ff120df78ccd85d6e2e2938ee02d1eb831676724/redis/lock.py#L235
-    lock.reacquire()
-    process_repeater(repeater, lock_token)
-
-
-def get_repeater_lock(repeater_id, lock_token=None):
-    name = f'process_repeater_{repeater_id}'
-    half_an_hour = 30 * 60
-    lock = get_redis_lock(key=name, name=name, timeout=half_an_hour)
-    if lock_token:
-        lock.local.token = lock_token
-    return lock
-
-
 def get_repeater_ids_by_domain():
     repeater_ids_by_domain = Repeater.objects.get_all_ready_ids_by_domain()
     always_enabled_domains = set(toggles.PROCESS_REPEATERS.get_enabled_domains())
@@ -521,13 +493,90 @@ def update_repeater(repeat_record_states, repeater_id, lock_token, more):
             )
             repeater.set_backoff()
     finally:
+        lock = RepeaterLock(repeater, lock_token)
         if more:
-            relock_and_process_repeater(repeater, lock_token)
+            lock.reacquire()
+            process_repeater(repeater, lock_token)
         else:
-            lock = get_repeater_lock(repeater_id, lock_token)
             lock.release()
 
 
+class RepeaterLock:
+    """
+    A utility class for encapsulating lock-related logic for a repeater.
+
+    The lock token determines ownership, so a lock acquired by one
+    process can be reacquired or released by another process that is
+    given the same token.
+    """
+
+    timeout = 30 * 60  # Half an hour
+
+    def __init__(self, repeater, lock_token=None):
+        """
+        :param repeater: A ``Repeater`` instance, or a repeater ID
+        :param lock_token: The token of a lock that is already held,
+            or ``None`` if the lock has not been acquired yet
+        """
+        if isinstance(repeater, Repeater):
+            self.repeater_id = repeater.repeater_id
+            self._repeater = repeater
+        else:
+            self.repeater_id = repeater
+            self._repeater = None
+        self.token = lock_token
+        self._lock = self._get_lock()
+
+    @property
+    def repeater(self):
+        """The repeater, fetched on first access if only an ID was given."""
+        if self._repeater is None:
+            self._repeater = Repeater.objects.get(id=self.repeater_id)
+        return self._repeater
+
+    def acquire(self):
+        """
+        Try to acquire the lock without blocking.
+
+        ``self.token`` is set only if the lock was acquired.
+
+        :return: The result of the underlying lock's ``acquire()``
+        """
+        assert self.token is None, 'You have already acquired this lock'
+        # Generate a lock token using `uuid1()` the same way that
+        # `redis.lock.Lock` does. The `Lock` class uses the token to
+        # determine ownership, so that one process can acquire a
+        # lock and a different process can release it. This lock
+        # will be released by the `update_repeater()` task.
+        token = uuid.uuid1().hex
+        acquired = self._lock.acquire(blocking=False, token=token)
+        if acquired:
+            # Only keep the token once the lock is actually held, so that
+            # a failed attempt does not look like an acquired lock.
+            self.token = token
+        return acquired
+
+    def reacquire(self):
+        """Reset the timeout of a lock that is already held."""
+        assert self.token, 'Missing lock token'
+        # Reset the lock timeout
+        # https://github.com/redis/redis-py/blob/ff120df78ccd85d6e2e2938ee02d1eb831676724/redis/lock.py#L235
+        return self._lock.reacquire()
+
+    def release(self):
+        """Release a lock that is already held."""
+        assert self.token, 'Missing lock token'
+        return self._lock.release()
+
+    def _get_lock(self):
+        """Return this repeater's lock, with our token set if we have one."""
+        name = f'process_repeater_{self.repeater_id}'
+        lock = get_redis_lock(key=name, name=name, timeout=self.timeout)
+        if self.token:
+            lock.local.token = self.token
+        return lock
+
+
 metrics_gauge_task(
     'commcare.repeaters.overdue',
     RepeatRecord.objects.count_overdue,
diff --git a/corehq/motech/repeaters/tests/test_tasks.py b/corehq/motech/repeaters/tests/test_tasks.py
index b18113a0b5cf..66def3c59860 100644
--- a/corehq/motech/repeaters/tests/test_tasks.py
+++ b/corehq/motech/repeaters/tests/test_tasks.py
@@ -4,6 +4,7 @@
 
 from django.test import SimpleTestCase, TestCase
 
+import pytest
 from freezegun import freeze_time
 
 from corehq.motech.models import ConnectionSettings, RequestLog
@@ -12,6 +13,7 @@
 from ..const import State
 from ..models import FormRepeater, Repeater, RepeatRecord
 from ..tasks import (
+    RepeaterLock,
     _get_wait_duration_seconds,
     _process_repeat_record,
     delete_old_request_logs,
@@ -290,7 +292,7 @@ def test_get_repeater_ids_by_domain():
 @flag_enabled('PROCESS_REPEATERS')
 class TestUpdateRepeater(SimpleTestCase):
 
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_resets_backoff_on_success(self, mock_get_repeater, __):
         repeat_record_states = [State.Success, State.Fail, State.Empty, None]
@@ -301,7 +303,7 @@ def test_update_repeater_resets_backoff_on_success(self, mock_get_repeater, __):
         mock_repeater.set_backoff.assert_not_called()
         mock_repeater.reset_backoff.assert_called_once()
 
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_resets_backoff_on_invalid(self, mock_get_repeater, __):
         repeat_record_states = [State.InvalidPayload, State.Fail, State.Empty, None]
@@ -312,7 +314,7 @@ def test_update_repeater_resets_backoff_on_invalid(self, mock_get_repeater, __):
         mock_repeater.set_backoff.assert_not_called()
         mock_repeater.reset_backoff.assert_called_once()
 
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_sets_backoff_on_failure(self, mock_get_repeater, __):
         repeat_record_states = [State.Fail, State.Empty, None]
@@ -323,7 +325,7 @@ def test_update_repeater_sets_backoff_on_failure(self, mock_get_repeater, __):
         mock_repeater.set_backoff.assert_called_once()
         mock_repeater.reset_backoff.assert_not_called()
 
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_does_nothing_on_empty(self, mock_get_repeater, __):
         repeat_record_states = [State.Empty]
@@ -334,7 +336,7 @@ def test_update_repeater_does_nothing_on_empty(self, mock_get_repeater, __):
         mock_repeater.set_backoff.assert_not_called()
         mock_repeater.reset_backoff.assert_not_called()
 
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_does_nothing_on_none(self, mock_get_repeater, __):
         repeat_record_states = [None]
@@ -346,7 +348,7 @@ def test_update_repeater_does_nothing_on_none(self, mock_get_repeater, __):
         mock_repeater.reset_backoff.assert_not_called()
 
     @patch('corehq.motech.repeaters.tasks.process_repeater')
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_calls_process_repeater_on_more(
         self,
@@ -363,7 +365,7 @@ def test_update_repeater_calls_process_repeater_on_more(
 
         mock_process_repeater.assert_called_once_with(mock_repeater, 'token')
 
-    @patch('corehq.motech.repeaters.tasks.get_repeater_lock')
+    @patch('corehq.motech.repeaters.tasks.RepeaterLock')
     @patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
     def test_update_repeater_releases_lock_on_no_more(
         self,
@@ -441,3 +443,46 @@ def test_repeat_record_two_attempts(self):
         )
         wait_duration = _get_wait_duration_seconds(repeat_record)
         self.assertEqual(wait_duration, 5)
+
+
+class TestRepeaterLock(TestCase):
+
+    def test_lock_repeater(self):
+        repeater = self._get_repeater()
+        lock = RepeaterLock(repeater.repeater_id)
+        assert lock.repeater == repeater
+
+    def test_lock_name(self):
+        lock = RepeaterLock('abc123')
+        self.assertEqual(lock._lock.name, 'process_repeater_abc123')
+
+    def test_acquire(self):
+        repeater = self._get_repeater()
+        lock = RepeaterLock(repeater)
+        assert lock.acquire()
+        assert lock.token
+
+    def test_acquire_assert(self):
+        lock = RepeaterLock('repeater_id', 'lock_token')
+        with pytest.raises(AssertionError, match=r'.* already acquired .*'):
+            lock.acquire()
+
+    def test_reacquire_assert(self):
+        lock = RepeaterLock('repeater_id')
+        with pytest.raises(AssertionError, match=r'Missing lock token'):
+            lock.reacquire()
+
+    def test_release_assert(self):
+        lock = RepeaterLock('repeater_id')
+        with pytest.raises(AssertionError, match=r'Missing lock token'):
+            lock.release()
+
+    @staticmethod
+    def _get_repeater():
+        return FormRepeater.objects.create(
+            domain=DOMAIN,
+            connection_settings=ConnectionSettings.objects.create(
+                domain=DOMAIN,
+                url='http://www.example.com/api/'
+            ),
+        )
diff --git a/corehq/tests/pytest_plugins/redislocks.py b/corehq/tests/pytest_plugins/redislocks.py
index f614368ed512..dc95f7af742b 100644
--- a/corehq/tests/pytest_plugins/redislocks.py
+++ b/corehq/tests/pytest_plugins/redislocks.py
@@ -1,6 +1,7 @@
 """A plugin that causes blocking redis locks to error on lock timeout"""
 import logging
 from datetime import datetime
+from unittest.mock import Mock
 
 import attr
 import pytest
@@ -48,6 +49,8 @@ class TestLock:
     lock = attr.ib(repr=False)
     timeout = attr.ib()
 
+    local = Mock()
+
     def acquire(self, **kw):
         start = datetime.now()
         log.info("acquire %s", self)