Skip to content

Commit

Permalink
Encapsulate locking in a class
Browse files Browse the repository at this point in the history
  • Loading branch information
kaapstorm committed Jan 22, 2025
1 parent ebec20d commit 09d757e
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 40 deletions.
90 changes: 57 additions & 33 deletions corehq/motech/repeaters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ def process_repeaters():
continue
if rate_limit_repeater(domain, repeater_id):
continue
lock_and_process_repeater_id(repeater_id)
lock = RepeaterLock(repeater_id)
if lock.acquire():
process_repeater(lock.repeater, lock.token)


def iter_ready_repeater_ids():
Expand Down Expand Up @@ -341,36 +343,6 @@ def iter_ready_repeater_ids():
yield domain, repeater_id


def lock_and_process_repeater_id(repeater_id):
    """
    Attempt a non-blocking lock on the repeater with the given ID and,
    if the lock is obtained, fetch the repeater and process it.

    Nothing happens when the lock cannot be obtained.
    """
    repeater_lock = get_repeater_lock(repeater_id)
    # The token is created with `uuid1()`, mirroring what
    # `redis.lock.Lock` does itself. `Lock` decides ownership by
    # token, which lets one process acquire the lock and another
    # process release it. The `update_repeater()` task is the one
    # that releases this lock.
    token = uuid.uuid1().hex
    if not repeater_lock.acquire(blocking=False, token=token):
        return
    process_repeater(Repeater.objects.get(id=repeater_id), token)


def relock_and_process_repeater(repeater, lock_token):
    """
    Refresh the repeater's lock using ``lock_token`` and then process
    the repeater again with that same token.
    """
    # `reacquire()` resets the lock timeout. See
    # https://github.com/redis/redis-py/blob/ff120df78ccd85d6e2e2938ee02d1eb831676724/redis/lock.py#L235
    get_repeater_lock(repeater.repeater_id, lock_token).reacquire()
    process_repeater(repeater, lock_token)


def get_repeater_lock(repeater_id, lock_token=None):
    """
    Build the Redis lock used while a repeater is being processed.

    The lock is named ``process_repeater_<repeater_id>`` and has a
    timeout of half an hour. If ``lock_token`` is given, it is set as
    the lock's local token before the lock is returned.
    """
    lock_name = f'process_repeater_{repeater_id}'
    timeout = 30 * 60  # half an hour
    redis_lock = get_redis_lock(key=lock_name, name=lock_name, timeout=timeout)
    if not lock_token:
        return redis_lock
    redis_lock.local.token = lock_token
    return redis_lock


def get_repeater_ids_by_domain():
repeater_ids_by_domain = Repeater.objects.get_all_ready_ids_by_domain()
always_enabled_domains = set(toggles.PROCESS_REPEATERS.get_enabled_domains())
Expand Down Expand Up @@ -521,13 +493,65 @@ def update_repeater(repeat_record_states, repeater_id, lock_token, more):
)
repeater.set_backoff()
finally:
lock = RepeaterLock(repeater, lock_token)
if more:
relock_and_process_repeater(repeater, lock_token)
lock.reacquire()
process_repeater(repeater, lock_token)
else:
lock = get_repeater_lock(repeater_id, lock_token)
lock.release()


class RepeaterLock:
    """
    A utility class for encapsulating lock-related logic for a repeater.

    The lock is identified by a token rather than by the object that
    acquired it, so one process can call ``acquire()`` and pass
    ``token`` to another process, which builds a new ``RepeaterLock``
    with that token and calls ``reacquire()`` or ``release()``.
    """

    timeout = 30 * 60  # Half an hour

    def __init__(self, repeater, lock_token=None):
        """
        :param repeater: Either a ``Repeater`` instance or a repeater ID.
        :param lock_token: The token of a lock that has already been
            acquired, if any. Leave as ``None`` to call ``acquire()``.
        """
        if isinstance(repeater, Repeater):
            self.repeater_id = repeater.repeater_id
            self._repeater = repeater
        else:
            self.repeater_id = repeater
            self._repeater = None
        self.token = lock_token
        self._lock = self._get_lock()

    @property
    def repeater(self):
        """
        The ``Repeater`` for this lock. It is fetched by ID on first
        access if the lock was created from a repeater ID.
        """
        if self._repeater is None:
            self._repeater = Repeater.objects.get(id=self.repeater_id)
        return self._repeater

    def acquire(self):
        """
        Try to acquire the lock without blocking.

        ``self.token`` is set only when the lock is acquired, so a
        failed attempt leaves the instance unchanged and ``acquire()``
        can be called again.

        :return: The result of the underlying lock's ``acquire()``.
        :raises AssertionError: If this instance already has a token.
        """
        assert self.token is None, 'You have already acquired this lock'
        # Generate a lock token using `uuid1()` the same way that
        # `redis.lock.Lock` does. The `Lock` class uses the token to
        # determine ownership, so that one process can acquire a
        # lock and a different process can release it. This lock
        # will be released by the `update_repeater()` task.
        token = uuid.uuid1().hex
        acquired = self._lock.acquire(blocking=False, token=token)
        if acquired:
            # Only remember the token once we really own the lock.
            # Otherwise `reacquire()` and `release()` would pass their
            # token checks using a token that owns nothing.
            self.token = token
        return acquired

    def reacquire(self):
        """
        Reset the timeout of a lock that is already held.

        :raises AssertionError: If this instance has no token.
        """
        assert self.token, 'Missing lock token'
        # Reset the lock timeout
        # https://github.com/redis/redis-py/blob/ff120df78ccd85d6e2e2938ee02d1eb831676724/redis/lock.py#L235
        return self._lock.reacquire()

    def release(self):
        """
        Release a lock that is already held.

        :raises AssertionError: If this instance has no token.
        """
        assert self.token, 'Missing lock token'
        return self._lock.release()

    def _get_lock(self):
        """
        Build the underlying lock, named after the repeater ID. If a
        token is already known, it is set as the lock's local token.
        """
        name = f'process_repeater_{self.repeater_id}'
        lock = get_redis_lock(key=name, name=name, timeout=self.timeout)
        if self.token:
            lock.local.token = self.token
        return lock


metrics_gauge_task(
'commcare.repeaters.overdue',
RepeatRecord.objects.count_overdue,
Expand Down
59 changes: 52 additions & 7 deletions corehq/motech/repeaters/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from django.test import SimpleTestCase, TestCase

import pytest
from freezegun import freeze_time

from corehq.motech.models import ConnectionSettings, RequestLog
Expand All @@ -12,6 +13,7 @@
from ..const import State
from ..models import FormRepeater, Repeater, RepeatRecord
from ..tasks import (
RepeaterLock,
_get_wait_duration_seconds,
_process_repeat_record,
delete_old_request_logs,
Expand Down Expand Up @@ -290,7 +292,7 @@ def test_get_repeater_ids_by_domain():
@flag_enabled('PROCESS_REPEATERS')
class TestUpdateRepeater(SimpleTestCase):

@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_resets_backoff_on_success(self, mock_get_repeater, __):
repeat_record_states = [State.Success, State.Fail, State.Empty, None]
Expand All @@ -301,7 +303,7 @@ def test_update_repeater_resets_backoff_on_success(self, mock_get_repeater, __):
mock_repeater.set_backoff.assert_not_called()
mock_repeater.reset_backoff.assert_called_once()

@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_resets_backoff_on_invalid(self, mock_get_repeater, __):
repeat_record_states = [State.InvalidPayload, State.Fail, State.Empty, None]
Expand All @@ -312,7 +314,7 @@ def test_update_repeater_resets_backoff_on_invalid(self, mock_get_repeater, __):
mock_repeater.set_backoff.assert_not_called()
mock_repeater.reset_backoff.assert_called_once()

@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_sets_backoff_on_failure(self, mock_get_repeater, __):
repeat_record_states = [State.Fail, State.Empty, None]
Expand All @@ -323,7 +325,7 @@ def test_update_repeater_sets_backoff_on_failure(self, mock_get_repeater, __):
mock_repeater.set_backoff.assert_called_once()
mock_repeater.reset_backoff.assert_not_called()

@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_does_nothing_on_empty(self, mock_get_repeater, __):
repeat_record_states = [State.Empty]
Expand All @@ -334,7 +336,7 @@ def test_update_repeater_does_nothing_on_empty(self, mock_get_repeater, __):
mock_repeater.set_backoff.assert_not_called()
mock_repeater.reset_backoff.assert_not_called()

@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_does_nothing_on_none(self, mock_get_repeater, __):
repeat_record_states = [None]
Expand All @@ -346,7 +348,7 @@ def test_update_repeater_does_nothing_on_none(self, mock_get_repeater, __):
mock_repeater.reset_backoff.assert_not_called()

@patch('corehq.motech.repeaters.tasks.process_repeater')
@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_calls_process_repeater_on_more(
self,
Expand All @@ -363,7 +365,7 @@ def test_update_repeater_calls_process_repeater_on_more(

mock_process_repeater.assert_called_once_with(mock_repeater, 'token')

@patch('corehq.motech.repeaters.tasks.get_repeater_lock')
@patch('corehq.motech.repeaters.tasks.RepeaterLock')
@patch('corehq.motech.repeaters.tasks.Repeater.objects.get')
def test_update_repeater_releases_lock_on_no_more(
self,
Expand Down Expand Up @@ -441,3 +443,46 @@ def test_repeat_record_two_attempts(self):
)
wait_duration = _get_wait_duration_seconds(repeat_record)
self.assertEqual(wait_duration, 5)


class TestRepeaterLock(TestCase):
    """Tests for the ``RepeaterLock`` helper class."""

    def test_lock_repeater(self):
        # A lock built from a repeater ID resolves to that repeater.
        expected = self._get_repeater()
        repeater_lock = RepeaterLock(expected.repeater_id)
        assert repeater_lock.repeater == expected

    def test_lock_name(self):
        repeater_lock = RepeaterLock('abc123')
        self.assertEqual(
            repeater_lock._lock.name,
            'process_repeater_abc123',
        )

    def test_acquire(self):
        repeater_lock = RepeaterLock(self._get_repeater())
        assert repeater_lock.acquire()
        assert repeater_lock.token

    def test_acquire_assert(self):
        # A lock created with a token must not be acquired again.
        repeater_lock = RepeaterLock('repeater_id', 'lock_token')
        with pytest.raises(AssertionError, match=r'.* already acquired .*'):
            repeater_lock.acquire()

    def test_reacquire_assert(self):
        # Reacquiring needs a token.
        repeater_lock = RepeaterLock('repeater_id')
        with pytest.raises(AssertionError, match=r'Missing lock token'):
            repeater_lock.reacquire()

    def test_release_assert(self):
        # Releasing needs a token.
        repeater_lock = RepeaterLock('repeater_id')
        with pytest.raises(AssertionError, match=r'Missing lock token'):
            repeater_lock.release()

    @staticmethod
    def _get_repeater():
        """Create and return a ``FormRepeater`` with its own connection settings."""
        connection_settings = ConnectionSettings.objects.create(
            domain=DOMAIN,
            url='http://www.example.com/api/'
        )
        return FormRepeater.objects.create(
            domain=DOMAIN,
            connection_settings=connection_settings,
        )
3 changes: 3 additions & 0 deletions corehq/tests/pytest_plugins/redislocks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A plugin that causes blocking redis locks to error on lock timeout"""
import logging
from datetime import datetime
from unittest.mock import Mock

import attr
import pytest
Expand Down Expand Up @@ -48,6 +49,8 @@ class TestLock:
lock = attr.ib(repr=False)
timeout = attr.ib()

local = Mock()

def acquire(self, **kw):
start = datetime.now()
log.info("acquire %s", self)
Expand Down

0 comments on commit 09d757e

Please sign in to comment.