From e72e4ec0993b53c565f58fb0bef7ae7b1439b431 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 19 Feb 2024 21:08:24 -0800 Subject: [PATCH] Fixes and tests for backgound Jobs (#358) * Add tests for jobs, attempt to fix JobNotFound integrity error https://amii.sentry.io/share/issue/512c752a224641a8b046a82c1d0fd9c5/ * Don't report shuffling for single images * Skip job retry if not found, to help debug this issue * Fix unauthenticated request --- ami/jobs/models.py | 11 +-- ami/jobs/tasks.py | 3 +- ami/jobs/tests.py | 92 ++++++++++++++++++- ami/jobs/views.py | 7 +- ami/main/api/views.py | 6 +- .../hooks/captures/useStarCapture.ts | 2 +- 6 files changed, 107 insertions(+), 14 deletions(-) diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 9fadd54f4..c3d478748 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -304,7 +304,8 @@ def enqueue(self): self.finished_at = None self.scheduled_at = datetime.datetime.now() self.status = run_job.AsyncResult(task_id).status - self.save(force_update=True) + self.update_progress(save=False) + self.save() def setup(self, save=True): """ @@ -322,7 +323,7 @@ def setup(self, save=True): self.progress.add_stage_param(collect_stage.key, "Total Images", "") pipeline_stage = self.progress.add_stage("Process") - self.progress.add_stage_param(pipeline_stage.key, "Proccessed", "") + self.progress.add_stage_param(pipeline_stage.key, "Processed", "") self.progress.add_stage_param(pipeline_stage.key, "Remaining", "") self.progress.add_stage_param(pipeline_stage.key, "Detections", "") self.progress.add_stage_param(pipeline_stage.key, "Classifications", "") @@ -330,8 +331,6 @@ def setup(self, save=True): saving_stage = self.progress.add_stage("Results") self.progress.add_stage_param(saving_stage.key, "Objects created", "") - self.save() - if save: self.save() @@ -393,7 +392,7 @@ def run(self): source_image_count = len(images) self.progress.update_stage("collect", total_images=source_image_count) - if self.shuffle: + if self.shuffle and source_image_count > 1: self.logger.info("Shuffling images") random.shuffle(images) @@ -505,7 +504,7 @@ def update_progress(self, save=True): Update the total aggregate progress from the progress of each stage. """ if not len(self.progress.stages): - total_progress = 0 + total_progress = 1 else: for stage in self.progress.stages: if stage.status == JobState.SUCCESS and stage.progress < 1: diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index 7a491c81c..b12271178 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -16,7 +16,8 @@ def run_job(self, job_id: int) -> None: try: job = Job.objects.get(pk=job_id) except Job.DoesNotExist as e: - self.retry(exc=e, countdown=1, max_retries=1) + raise e + # self.retry(exc=e, countdown=1, max_retries=1) else: job.logger.info(f"Running job {job}") try: diff --git a/ami/jobs/tests.py b/ami/jobs/tests.py index 1779f52ae..870d4ffb8 100644 --- a/ami/jobs/tests.py +++ b/ami/jobs/tests.py @@ -1,10 +1,15 @@ from django.test import TestCase +from rest_framework.test import APIRequestFactory, APITestCase +from ami.base.serializers import reverse_with_params from ami.jobs.models import Job, JobProgress, JobState from ami.main.models import Project +from ami.users.models import User +# from rich import print -class TestJobProgres(TestCase): + +class TestJobProgress(TestCase): def setUp(self): self.project = Project.objects.create(name="Job test") @@ -31,3 +36,88 @@ def test_create_job_with_delay(self): self.assertEqual(job.progress.summary.status, JobState.SUCCESS) self.assertEqual(job.progress.stages[0].progress, 1) self.assertEqual(job.progress.stages[0].status, JobState.SUCCESS) + + +class TestJobView(APITestCase): + """ + Test the jobs API endpoints. + """ + + def setUp(self): + self.project = Project.objects.create(name="Jobs Test Project") + self.job = Job.objects.create(project=self.project, name="Test job", delay=0) + self.user = User.objects.create_user( # type: ignore + email="testuser@insectai.org", + ) + self.factory = APIRequestFactory() + + def test_get_job(self): + # resp = self.client.get(f"/api/jobs/{self.job.pk}/") + jobs_retrieve_url = reverse_with_params("api:job-detail", args=[self.job.pk]) + resp = self.client.get(jobs_retrieve_url) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json()["id"], self.job.pk) + + def test_get_job_list(self): + # resp = self.client.get("/api/jobs/") + jobs_list_url = reverse_with_params("api:job-list") + resp = self.client.get(jobs_list_url) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp.json()["count"], 1) + + def test_create_job_unauthenticated(self): + jobs_create_url = reverse_with_params("api:job-list") + job_data = { + "project_id": self.project.pk, + "name": "Test job unauthenticated", + "delay": 0, + } + self.client.force_authenticate(user=None) + resp = self.client.post(jobs_create_url, job_data) + self.assertEqual(resp.status_code, 403) + + def test_create_job(self): + jobs_create_url = reverse_with_params("api:job-list") + # request = self.factory.post(jobs_create_url, {"project": self.project.pk, "name": "Test job 2"}) + self.client.force_authenticate(user=self.user) + job_data = { + "project_id": self.project.pk, + "name": "Test job 2", + "delay": 0, + } + resp = self.client.post(jobs_create_url, job_data) + self.client.force_authenticate(user=None) + self.assertEqual(resp.status_code, 201) + data = resp.json() + self.assertEqual(data["project"]["id"], self.project.pk) + self.assertEqual(data["name"], "Test job 2") + # self.assertEqual(data["progress"]["status"], "CREATED") + progress = JobProgress(**data["progress"]) + self.assertEqual(progress.summary.status, JobState.CREATED) + + def test_run_job(self): + jobs_run_url = reverse_with_params("api:job-run", args=[self.job.pk], params={"no_async": True}) + self.client.force_authenticate(user=self.user) + resp = self.client.post(jobs_run_url) + self.client.force_authenticate(user=None) + self.assertEqual(resp.status_code, 200) + data = resp.json() + self.assertEqual(data["id"], self.job.pk) + self.assertEqual(data["status"], JobState.SUCCESS.value) + progress = JobProgress(**data["progress"]) + self.assertEqual(progress.summary.status, JobState.SUCCESS) + self.assertEqual(progress.summary.progress, 1.0) + # self.job.refresh_from_db() + # Assert has a task id now, if async is working in tests + # self.assertIsNotNone(self.job.task_id) + + def test_run_job_unauthenticated(self): + jobs_run_url = reverse_with_params("api:job-run", args=[self.job.pk]) + self.client.force_authenticate(user=None) + resp = self.client.post(jobs_run_url) + self.assertEqual(resp.status_code, 403) + + def test_cancel_job(self): + # This cannot be tested until we have a way to cancel jobs + # and a way to run async tasks in tests. + pass diff --git a/ami/jobs/views.py b/ami/jobs/views.py index ca8bb1595..e7e932854 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -74,8 +74,11 @@ def run(self, request, pk=None): Run a job (add it to the queue). """ job: Job = self.get_object() - # job.run() - job.enqueue() + no_async = url_boolean_param(request, "no_async", default=False) + if no_async: + job.run() + else: + job.enqueue() job.refresh_from_db() return Response(self.get_serializer(job).data) diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 6fd475e95..244a1bc3b 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -98,7 +98,7 @@ class DefaultViewSetMixin: filterset_fields = [] ordering_fields = ["created_at", "updated_at"] search_fields = [] - permission_classes = [permissions.AllowAny] + permission_classes = [permissions.IsAuthenticatedOrReadOnly] class DefaultViewSet(DefaultViewSetMixin, viewsets.ModelViewSet): @@ -693,7 +693,7 @@ class ClassificationViewSet(DefaultViewSet): class SummaryView(APIView): - permission_classes = [permissions.AllowAny] + permission_classes = [permissions.IsAuthenticatedOrReadOnly] filterset_fields = ["project"] def get(self, request): @@ -770,7 +770,7 @@ class StorageStatus(APIView): Return the status of the storage connection. """ - permission_classes = [permissions.AllowAny] + permission_classes = [permissions.IsAuthenticatedOrReadOnly] serializer_class = StorageStatusSerializer def post(self, request): diff --git a/ui/src/data-services/hooks/captures/useStarCapture.ts b/ui/src/data-services/hooks/captures/useStarCapture.ts index fde0bc32e..f847255ac 100644 --- a/ui/src/data-services/hooks/captures/useStarCapture.ts +++ b/ui/src/data-services/hooks/captures/useStarCapture.ts @@ -14,7 +14,7 @@ export const useStarCapture = (id: string, isStarred: boolean, onSuccess?: () => const { mutateAsync, isLoading, isSuccess, error } = useMutation({ mutationFn: () => - axios.post(mutationUrl, { + axios.post(mutationUrl, {}, { headers: getAuthHeader(user), }), onSuccess: () => {