Skip to content

Commit

Permalink
Update with changes from core repo (#6)
Browse files Browse the repository at this point in the history
* feat(task-processor): add Task Processor inputs as env vars (#3355)

* fix: prevent tasks dying from temporary loss of db connection (#3674)

* fix(task-processor): catch all exceptions (#3737)

* infra: run influxdb feature evaluation in thread (#4125)

* fix: Stale connections after task processor errors (#4179)

* Formatting fix

* chore(task-processor): add better logging for failed tasks (#4186)
  • Loading branch information
matthewelwell authored Jun 19, 2024
1 parent d3018fa commit 5f43c67
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 29 deletions.
3 changes: 2 additions & 1 deletion task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def delay(
def run_in_thread(
self,
*,
args: tuple[typing.Any] = (),
args: tuple[typing.Any, ...] = (),
kwargs: dict[str, typing.Any] | None = None,
) -> None:
kwargs = kwargs or {}
_validate_inputs(*args, **kwargs)
thread = Thread(target=self.unwrapped, args=args, kwargs=kwargs, daemon=True)

Expand Down
5 changes: 5 additions & 0 deletions task_processor/management/commands/runprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def handle(self, *args, **options):
grace_period_ms = options["graceperiodms"]
queue_pop_size = options["queuepopsize"]

logger.debug(
"Running task processor with args: %s",
",".join([f"{k}={v}" for k, v in options.items()]),
)

self._threads.extend(
[
TaskRunner(
Expand Down
10 changes: 9 additions & 1 deletion task_processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,15 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
task_run.finished_at = timezone.now()
task.mark_success()
except Exception as e:
logger.warning(e)
logger.warning(
"Failed to execute task '%s'. Exception was: %s",
task.task_identifier,
str(e),
exc_info=True,
)
logger.debug("args: %s", str(task.args))
logger.debug("kwargs: %s", str(task.kwargs))

task.mark_failure()

task_run.result = TaskResult.FAILURE
Expand Down
21 changes: 16 additions & 5 deletions task_processor/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from threading import Thread

from django.db import close_old_connections
from django.utils import timezone

from task_processor.processor import run_recurring_tasks, run_tasks
Expand All @@ -27,12 +28,22 @@ def __init__(
def run(self) -> None:
    """Main thread loop: poll for tasks until stop() is called.

    Each pass records when the queue was last checked, delegates the
    actual work (and its error handling) to run_iteration(), then sleeps
    for the configured interval.
    """
    # NOTE: the diff residue here interleaved the pre-commit try/except with
    # the post-commit call; the error handling now lives in run_iteration().
    while not self._stopped:
        self.last_checked_for_tasks = timezone.now()
        self.run_iteration()
        # sleep_interval_millis is milliseconds; time.sleep wants seconds.
        time.sleep(self.sleep_interval_millis / 1000)

def run_iteration(self) -> None:
    """Poll the database once for ordinary and recurring tasks and run them."""
    try:
        run_tasks(self.queue_pop_size)
        run_recurring_tasks(self.queue_pop_size)
    except Exception as exc:
        # Swallow the error so that a transient failure (for example a
        # dropped database connection) cannot kill the runner thread. The
        # thread keeps looping and will pick up tasks again once a fresh
        # connection can be established.
        # TODO: is this also what is causing tasks to get stuck as locked?
        #       Can we unlock tasks here?
        logger.error("Received error retrieving tasks: %s.", exc, exc_info=exc)
        # Drop any broken/stale connections so the next iteration opens a
        # healthy one.
        close_old_connections()

def stop(self) -> None:
    """Signal the run() loop to exit after it finishes its current iteration."""
    self._stopped = True
28 changes: 28 additions & 0 deletions tests/unit/task_processor/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,34 @@
import logging
import typing

import pytest


@pytest.fixture
def run_by_processor(monkeypatch):
    """Set the RUN_BY_PROCESSOR env var so code under test behaves as if it
    were executed by the task processor (monkeypatch undoes this on teardown)."""
    monkeypatch.setenv("RUN_BY_PROCESSOR", "True")


class GetTaskProcessorCaplog(typing.Protocol):
    """Type of the ``get_task_processor_caplog`` fixture's return value: a
    callable that enables capture of ``task_processor`` logger output at the
    given level and hands back the configured ``caplog`` fixture."""

    def __call__(self, log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture:
        ...


@pytest.fixture
def get_task_processor_caplog(
    caplog: pytest.LogCaptureFixture,
) -> typing.Generator[GetTaskProcessorCaplog, None, None]:
    """Return a callable that captures ``task_processor`` logger output.

    caplog doesn't allow you to capture logging outputs from loggers that
    don't propagate to root. Quick hack here to get the task_processor
    logger to propagate.
    TODO: look into using loguru.
    """
    task_processor_logger = logging.getLogger("task_processor")
    # Remember the logger's original configuration so it can be restored on
    # teardown; previously ``propagate``/level were left mutated, leaking
    # state into tests that run afterwards in the same session.
    original_propagate = task_processor_logger.propagate
    original_level = task_processor_logger.level

    def _inner(log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture:
        task_processor_logger.propagate = True
        # Assume required level for the logger.
        task_processor_logger.setLevel(log_level)
        caplog.set_level(log_level)
        return caplog

    yield _inner

    # Teardown: undo the logger mutation so other tests see the original state.
    task_processor_logger.propagate = original_propagate
    task_processor_logger.setLevel(original_level)
25 changes: 7 additions & 18 deletions tests/unit/task_processor/test_unit_task_processor_decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import logging
from datetime import timedelta
from unittest.mock import MagicMock

Expand All @@ -16,19 +15,7 @@
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import get_task
from task_processor.task_run_method import TaskRunMethod


@pytest.fixture
def capture_task_processor_logger(caplog: pytest.LogCaptureFixture) -> None:
# caplog doesn't allow you to capture logging outputs from loggers that don't
# propagate to root. Quick hack here to get the task_processor logger to
# propagate.
# TODO: look into using loguru.
task_processor_logger = logging.getLogger("task_processor")
task_processor_logger.propagate = True
# Assume required level for the logger.
task_processor_logger.setLevel(logging.INFO)
caplog.set_level(logging.INFO)
from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


@pytest.fixture
Expand All @@ -44,11 +31,12 @@ def mock_thread_class(

@pytest.mark.django_db
def test_register_task_handler_run_in_thread__transaction_commit__true__default(
capture_task_processor_logger: None,
caplog: pytest.LogCaptureFixture,
get_task_processor_caplog: GetTaskProcessorCaplog,
mock_thread_class: MagicMock,
) -> None:
# Given
caplog = get_task_processor_caplog()

@register_task_handler()
def my_function(*args: str, **kwargs: str) -> None:
pass
Expand Down Expand Up @@ -77,11 +65,12 @@ def my_function(*args: str, **kwargs: str) -> None:


def test_register_task_handler_run_in_thread__transaction_commit__false(
capture_task_processor_logger: None,
caplog: pytest.LogCaptureFixture,
get_task_processor_caplog: GetTaskProcessorCaplog,
mock_thread_class: MagicMock,
) -> None:
# Given
caplog = get_task_processor_caplog()

@register_task_handler(transaction_on_commit=False)
def my_function(*args, **kwargs):
pass
Expand Down
32 changes: 28 additions & 4 deletions tests/unit/task_processor/test_unit_task_processor_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,18 @@ def _a_task():
)


def test_run_task_runs_task_and_creates_task_run_object_when_failure(db):
def test_run_task_runs_task_and_creates_task_run_object_when_failure(
db: None, caplog: pytest.LogCaptureFixture
) -> None:
# Given
task = Task.create(_raise_exception.task_identifier, scheduled_for=timezone.now())
task_processor_logger = logging.getLogger("task_processor")
task_processor_logger.propagate = True
task_processor_logger.level = logging.DEBUG

msg = "Error!"
task = Task.create(
_raise_exception.task_identifier, args=(msg,), scheduled_for=timezone.now()
)
task.save()

# When
Expand All @@ -229,6 +238,21 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure(db):
task.refresh_from_db()
assert not task.completed

assert len(caplog.records) == 3

warning_log = caplog.records[0]
assert warning_log.levelname == "WARNING"
assert warning_log.message == (
f"Failed to execute task '{task.task_identifier}'. Exception was: {msg}"
)

debug_log_args, debug_log_kwargs = caplog.records[1:]
assert debug_log_args.levelname == "DEBUG"
assert debug_log_args.message == f"args: ['{msg}']"

assert debug_log_kwargs.levelname == "DEBUG"
assert debug_log_kwargs.message == "kwargs: {}"


def test_run_task_runs_failed_task_again(db):
# Given
Expand Down Expand Up @@ -440,8 +464,8 @@ def _dummy_task(key: str = DEFAULT_CACHE_KEY, value: str = DEFAULT_CACHE_VALUE):


@register_task_handler()
def _raise_exception(msg: str):
    # Test helper task that always fails, carrying *msg* in the exception so
    # failure logging/handling can be asserted on.
    # NOTE: the diff residue here interleaved the old zero-argument definition
    # with the new one-argument definition; this is the post-commit version.
    raise Exception(msg)


@register_task_handler()
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/task_processor/test_unit_task_processor_threads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging
from typing import Type

import pytest
from django.db import DatabaseError
from pytest_mock import MockerFixture

from task_processor.threads import TaskRunner
from tests.unit.task_processor.conftest import GetTaskProcessorCaplog


@pytest.mark.parametrize(
    "exception_class, exception_message",
    [(DatabaseError, "Database error"), (Exception, "Generic error")],
)
def test_task_runner_is_resilient_to_errors(
    db: None,
    mocker: MockerFixture,
    get_task_processor_caplog: GetTaskProcessorCaplog,
    exception_class: Type[Exception],
    exception_message: str,
) -> None:
    """A single iteration must survive any exception from run_tasks —
    logging it at ERROR level rather than letting it escape and kill the
    runner thread."""
    # Given
    caplog = get_task_processor_caplog(logging.DEBUG)

    task_runner = TaskRunner()
    # Force task retrieval to fail with the parametrized exception.
    mocker.patch(
        "task_processor.threads.run_tasks",
        side_effect=exception_class(exception_message),
    )

    # When
    # Call one iteration directly — no exception should propagate.
    task_runner.run_iteration()

    # Then
    assert len(caplog.records) == 1

    assert caplog.records[0].levelno == logging.ERROR
    assert (
        caplog.records[0].message
        == f"Received error retrieving tasks: {exception_message}."
    )

0 comments on commit 5f43c67

Please sign in to comment.