Skip to content

Commit

Permalink
Fixes and tests for backgound Jobs (#358)
Browse files Browse the repository at this point in the history
* Add tests for jobs, attempt to fix JobNotFound integrity error

https://amii.sentry.io/share/issue/512c752a224641a8b046a82c1d0fd9c5/

* Don't report shuffling for single images

* Skip job retry if not found, to help debug this issue

* Fix unauthenticated request
  • Loading branch information
mihow authored Feb 20, 2024
1 parent b1d0b4f commit e72e4ec
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 14 deletions.
11 changes: 5 additions & 6 deletions ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,8 @@ def enqueue(self):
self.finished_at = None
self.scheduled_at = datetime.datetime.now()
self.status = run_job.AsyncResult(task_id).status
self.save(force_update=True)
self.update_progress(save=False)
self.save()

def setup(self, save=True):
"""
Expand All @@ -322,16 +323,14 @@ def setup(self, save=True):
self.progress.add_stage_param(collect_stage.key, "Total Images", "")

pipeline_stage = self.progress.add_stage("Process")
self.progress.add_stage_param(pipeline_stage.key, "Proccessed", "")
self.progress.add_stage_param(pipeline_stage.key, "Processed", "")
self.progress.add_stage_param(pipeline_stage.key, "Remaining", "")
self.progress.add_stage_param(pipeline_stage.key, "Detections", "")
self.progress.add_stage_param(pipeline_stage.key, "Classifications", "")

saving_stage = self.progress.add_stage("Results")
self.progress.add_stage_param(saving_stage.key, "Objects created", "")

self.save()

if save:
self.save()

Expand Down Expand Up @@ -393,7 +392,7 @@ def run(self):
source_image_count = len(images)
self.progress.update_stage("collect", total_images=source_image_count)

if self.shuffle:
if self.shuffle and source_image_count > 1:
self.logger.info("Shuffling images")
random.shuffle(images)

Expand Down Expand Up @@ -505,7 +504,7 @@ def update_progress(self, save=True):
Update the total aggregate progress from the progress of each stage.
"""
if not len(self.progress.stages):
total_progress = 0
total_progress = 1
else:
for stage in self.progress.stages:
if stage.status == JobState.SUCCESS and stage.progress < 1:
Expand Down
3 changes: 2 additions & 1 deletion ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def run_job(self, job_id: int) -> None:
try:
job = Job.objects.get(pk=job_id)
except Job.DoesNotExist as e:
self.retry(exc=e, countdown=1, max_retries=1)
raise e
# self.retry(exc=e, countdown=1, max_retries=1)
else:
job.logger.info(f"Running job {job}")
try:
Expand Down
92 changes: 91 additions & 1 deletion ami/jobs/tests.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from django.test import TestCase
from rest_framework.test import APIRequestFactory, APITestCase

from ami.base.serializers import reverse_with_params
from ami.jobs.models import Job, JobProgress, JobState
from ami.main.models import Project
from ami.users.models import User

# from rich import print

class TestJobProgres(TestCase):

class TestJobProgress(TestCase):
def setUp(self):
self.project = Project.objects.create(name="Job test")

Expand All @@ -31,3 +36,88 @@ def test_create_job_with_delay(self):
self.assertEqual(job.progress.summary.status, JobState.SUCCESS)
self.assertEqual(job.progress.stages[0].progress, 1)
self.assertEqual(job.progress.stages[0].status, JobState.SUCCESS)


class TestJobView(APITestCase):
"""
Test the jobs API endpoints.
"""

def setUp(self):
self.project = Project.objects.create(name="Jobs Test Project")
self.job = Job.objects.create(project=self.project, name="Test job", delay=0)
self.user = User.objects.create_user( # type: ignore
email="[email protected]",
)
self.factory = APIRequestFactory()

def test_get_job(self):
# resp = self.client.get(f"/api/jobs/{self.job.pk}/")
jobs_retrieve_url = reverse_with_params("api:job-detail", args=[self.job.pk])
resp = self.client.get(jobs_retrieve_url)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.json()["id"], self.job.pk)

def test_get_job_list(self):
# resp = self.client.get("/api/jobs/")
jobs_list_url = reverse_with_params("api:job-list")
resp = self.client.get(jobs_list_url)
self.assertEqual(resp.status_code, 200)
self.assertEqual(resp.json()["count"], 1)

def test_create_job_unauthenticated(self):
jobs_create_url = reverse_with_params("api:job-list")
job_data = {
"project_id": self.project.pk,
"name": "Test job unauthenticated",
"delay": 0,
}
self.client.force_authenticate(user=None)
resp = self.client.post(jobs_create_url, job_data)
self.assertEqual(resp.status_code, 403)

def test_create_job(self):
jobs_create_url = reverse_with_params("api:job-list")
# request = self.factory.post(jobs_create_url, {"project": self.project.pk, "name": "Test job 2"})
self.client.force_authenticate(user=self.user)
job_data = {
"project_id": self.project.pk,
"name": "Test job 2",
"delay": 0,
}
resp = self.client.post(jobs_create_url, job_data)
self.client.force_authenticate(user=None)
self.assertEqual(resp.status_code, 201)
data = resp.json()
self.assertEqual(data["project"]["id"], self.project.pk)
self.assertEqual(data["name"], "Test job 2")
# self.assertEqual(data["progress"]["status"], "CREATED")
progress = JobProgress(**data["progress"])
self.assertEqual(progress.summary.status, JobState.CREATED)

def test_run_job(self):
jobs_run_url = reverse_with_params("api:job-run", args=[self.job.pk], params={"no_async": True})
self.client.force_authenticate(user=self.user)
resp = self.client.post(jobs_run_url)
self.client.force_authenticate(user=None)
self.assertEqual(resp.status_code, 200)
data = resp.json()
self.assertEqual(data["id"], self.job.pk)
self.assertEqual(data["status"], JobState.SUCCESS.value)
progress = JobProgress(**data["progress"])
self.assertEqual(progress.summary.status, JobState.SUCCESS)
self.assertEqual(progress.summary.progress, 1.0)
# self.job.refresh_from_db()
# Assert has a task id now, if async is working in tests
# self.assertIsNotNone(self.job.task_id)

def test_run_job_unauthenticated(self):
jobs_run_url = reverse_with_params("api:job-run", args=[self.job.pk])
self.client.force_authenticate(user=None)
resp = self.client.post(jobs_run_url)
self.assertEqual(resp.status_code, 403)

def test_cancel_job(self):
# This cannot be tested until we have a way to cancel jobs
# and a way to run async tasks in tests.
pass
7 changes: 5 additions & 2 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ def run(self, request, pk=None):
Run a job (add it to the queue).
"""
job: Job = self.get_object()
# job.run()
job.enqueue()
no_async = url_boolean_param(request, "no_async", default=False)
if no_async:
job.run()
else:
job.enqueue()
job.refresh_from_db()
return Response(self.get_serializer(job).data)

Expand Down
6 changes: 3 additions & 3 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class DefaultViewSetMixin:
filterset_fields = []
ordering_fields = ["created_at", "updated_at"]
search_fields = []
permission_classes = [permissions.AllowAny]
permission_classes = [permissions.IsAuthenticatedOrReadOnly]


class DefaultViewSet(DefaultViewSetMixin, viewsets.ModelViewSet):
Expand Down Expand Up @@ -693,7 +693,7 @@ class ClassificationViewSet(DefaultViewSet):


class SummaryView(APIView):
permission_classes = [permissions.AllowAny]
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filterset_fields = ["project"]

def get(self, request):
Expand Down Expand Up @@ -770,7 +770,7 @@ class StorageStatus(APIView):
Return the status of the storage connection.
"""

permission_classes = [permissions.AllowAny]
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
serializer_class = StorageStatusSerializer

def post(self, request):
Expand Down
2 changes: 1 addition & 1 deletion ui/src/data-services/hooks/captures/useStarCapture.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export const useStarCapture = (id: string, isStarred: boolean, onSuccess?: () =>

const { mutateAsync, isLoading, isSuccess, error } = useMutation({
mutationFn: () =>
axios.post(mutationUrl, {
axios.post(mutationUrl, {}, {
headers: getAuthHeader(user),
}),
onSuccess: () => {
Expand Down

0 comments on commit e72e4ec

Please sign in to comment.