Skip to content

Commit

Permalink
Merge branch 'main' into feat/sort-by-physical-size
Browse files Browse the repository at this point in the history
  • Loading branch information
mihow authored Apr 23, 2024
2 parents 14fa31c + 8593825 commit 006b2d3
Show file tree
Hide file tree
Showing 36 changed files with 622 additions and 102 deletions.
2 changes: 1 addition & 1 deletion ami/jobs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def setUp(self):
self.job = Job.objects.create(project=self.project, name="Test job", delay=0)
self.user = User.objects.create_user( # type: ignore
email="[email protected]",
is_staff=True,
)
self.factory = APIRequestFactory()

Expand Down Expand Up @@ -99,7 +100,6 @@ def test_run_job(self):
jobs_run_url = reverse_with_params("api:job-run", args=[self.job.pk], params={"no_async": True})
self.client.force_authenticate(user=self.user)
resp = self.client.post(jobs_run_url)
self.client.force_authenticate(user=None)
self.assertEqual(resp.status_code, 200)
data = resp.json()
self.assertEqual(data["id"], self.job.pk)
Expand Down
36 changes: 33 additions & 3 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ class Meta:
"withdrawn",
"agreed_with_identification_id",
"agreed_with_prediction_id",
"comment",
"created_at",
"updated_at",
]
Expand Down Expand Up @@ -706,7 +707,7 @@ class Meta:
"classifications",
]

def get_classifications(self, obj):
def get_classifications(self, obj) -> str:
"""
Return URL to the classifications endpoint filtered by this detection.
"""
Expand Down Expand Up @@ -899,10 +900,38 @@ def validate_image(self, value):
return value


class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer):
# The most common kwargs for the sampling methods
# use for the "common_combined" method
minute_interval = serializers.IntegerField(required=False, allow_null=True)
max_num = serializers.IntegerField(required=False, allow_null=True)
month_start = serializers.IntegerField(required=False, allow_null=True)
month_end = serializers.IntegerField(required=False, allow_null=True)

hour_start = serializers.IntegerField(required=False, allow_null=True)
hour_end = serializers.IntegerField(required=False, allow_null=True)

# Kwargs for other sampling methods, this is not complete
# see the SourceImageCollection model for all available kwargs.
size = serializers.IntegerField(required=False, allow_null=True)
num_each = serializers.IntegerField(required=False, allow_null=True)
exclude_events = serializers.CharField(required=False, allow_null=True)
deployment_id = serializers.IntegerField(required=False, allow_null=True)
position = serializers.IntegerField(required=False, allow_null=True)

# Don't return the kwargs if they are empty
def to_representation(self, instance):
data = super().to_representation(instance)
return {key: value for key, value in data.items() if value is not None}


class SourceImageCollectionSerializer(DefaultSerializer):
# @TODO can sampling kwargs be a nested serializer instead??

source_images = serializers.SerializerMethodField()
kwargs = serializers.JSONField(initial=dict, required=False)
kwargs = SourceImageCollectionCommonKwargsSerializer(required=False, partial=True)
jobs = JobStatusSerializer(many=True, read_only=True)
project = serializers.PrimaryKeyRelatedField(queryset=Project.objects.all())

class Meta:
model = SourceImageCollection
Expand All @@ -920,7 +949,7 @@ class Meta:
"updated_at",
]

def get_source_images(self, obj):
def get_source_images(self, obj) -> str:
"""
Return URL to the captures endpoint filtered by this collection.
"""
Expand All @@ -944,6 +973,7 @@ class Meta:
"taxon",
"user",
"withdrawn",
"comment",
"created_at",
]

Expand Down
10 changes: 9 additions & 1 deletion ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.filters import SearchFilter
from rest_framework.generics import GenericAPIView
from rest_framework.response import Response
from rest_framework.views import APIView

from ami import tasks
from ami.base.filters import NullsLastOrderingFilter
from ami.base.pagination import LimitOffsetPaginationWithPermissions
from ami.base.permissions import IsActiveStaffOrReadOnly
from ami.utils.requests import get_active_classification_threshold

Expand Down Expand Up @@ -98,13 +100,18 @@ class DefaultReadOnlyViewSet(DefaultViewSetMixin, viewsets.ReadOnlyModelViewSet)
pass


class ProjectPagination(LimitOffsetPaginationWithPermissions):
default_limit = 20


class ProjectViewSet(DefaultViewSet):
"""
API endpoint that allows projects to be viewed or edited.
"""

queryset = Project.objects.filter(active=True).prefetch_related("deployments").all()
serializer_class = ProjectSerializer
pagination_class = ProjectPagination

def get_serializer_class(self):
"""
Expand Down Expand Up @@ -317,6 +324,7 @@ def populate(self, request, pk=None):
Populate a collection with source images using the configured sampling method and arguments.
"""
collection = self.get_object()
collection.images.clear()
task = tasks.populate_collection.apply_async([collection.pk])
return Response({"task": task.id})

Expand Down Expand Up @@ -702,7 +710,7 @@ class ClassificationViewSet(DefaultViewSet):
]


class SummaryView(APIView):
class SummaryView(GenericAPIView):
permission_classes = [IsActiveStaffOrReadOnly]
filterset_fields = ["project"]

Expand Down
17 changes: 17 additions & 0 deletions ami/main/migrations/0030_identification_comment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 4.2.2 on 2024-04-16 18:56

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("main", "0029_alter_deployment_device_and_more"),
]

operations = [
migrations.AddField(
model_name="identification",
name="comment",
field=models.TextField(blank=True),
),
]
49 changes: 46 additions & 3 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,26 +1133,31 @@ def set_dimensions_for_collection(


def sample_captures_by_interval(
minute_interval: int = 10, qs: models.QuerySet[SourceImage] | None = None
minute_interval: int = 10, qs: models.QuerySet[SourceImage] | None = None, max_num: int | None = None
) -> typing.Generator[SourceImage, None, None]:
"""
Return a sample of captures from the deployment, evenly spaced apart by minute_interval.
"""

last_capture = None
total = 0

if not qs:
qs = SourceImage.objects.all()
qs = qs.exclude(timestamp=None).order_by("timestamp")

for capture in qs.all():
if max_num and total >= max_num:
break
if not last_capture:
total += 1
yield capture
last_capture = capture
else:
assert capture.timestamp and last_capture.timestamp
delta: datetime.timedelta = capture.timestamp - last_capture.timestamp
if delta.total_seconds() >= minute_interval * 60:
total += 1
yield capture
last_capture = capture

Expand Down Expand Up @@ -1285,6 +1290,7 @@ class Identification(BaseModel):
related_name="agreed_identifications",
)
score = 1.0 # Always 1 for humans, at this time
comment = models.TextField(blank=True)

class Meta:
ordering = [
Expand Down Expand Up @@ -1390,6 +1396,10 @@ class Classification(BaseModel):
)
# job = models.CharField(max_length=255, null=True)

# Type hints for auto-generated fields
taxon_id: int
algorithm_id: int

class Meta:
ordering = ["-created_at", "-score"]

Expand Down Expand Up @@ -2132,6 +2142,7 @@ def html(self) -> str:


_SOURCE_IMAGE_SAMPLING_METHODS = [
"common_combined",
"random",
"stratified_random",
"interval",
Expand Down Expand Up @@ -2162,7 +2173,12 @@ class SourceImageCollection(BaseModel):
description = models.TextField(blank=True)
images = models.ManyToManyField("SourceImage", related_name="collections", blank=True)
project = models.ForeignKey(Project, on_delete=models.CASCADE, related_name="sourceimage_collections")
method = models.CharField(max_length=255, choices=as_choices(_SOURCE_IMAGE_SAMPLING_METHODS))
method = models.CharField(
max_length=255,
choices=as_choices(_SOURCE_IMAGE_SAMPLING_METHODS),
default="common_combined",
)
# @TODO this should be a JSON field with a schema, use a pydantic model
kwargs = models.JSONField(
"Arguments",
null=True,
Expand All @@ -2171,7 +2187,7 @@ class SourceImageCollection(BaseModel):
default=dict,
)

def source_image_count(self):
def source_image_count(self) -> int:
# This should always be pre-populated using queryset annotations
return self.images.count()

Expand Down Expand Up @@ -2206,6 +2222,33 @@ def sample_manual(self, image_ids: list[int]):
qs = self.get_queryset()
return qs.filter(id__in=image_ids)

def sample_common_combined(
self,
minute_interval: int | None = None,
max_num: int | None = 100,
hour_start: int | None = None,
hour_end: int | None = None,
month_start: datetime.date | None = None,
month_end: datetime.date | None = None,
) -> list[SourceImage]:
qs = self.get_queryset()
if month_start:
qs = qs.filter(timestamp__month__gte=month_start)
if month_end:
qs = qs.filter(timestamp__month__lte=month_end)
if hour_start:
qs = qs.filter(timestamp__hour__gte=hour_start)
if hour_end:
qs = qs.filter(timestamp__hour__lte=hour_end)
if minute_interval:
# @TODO can this be done in the database and return a queryset?
# this currently returns a list of source images
qs = list(sample_captures_by_interval(minute_interval, qs, max_num=max_num))
if max_num:
qs = qs[:max_num]
captures = list(qs)
return captures

def sample_interval(
self, minute_interval: int = 10, exclude_events: list[int] = [], deployment_id: int | None = None
):
Expand Down
50 changes: 50 additions & 0 deletions ami/main/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from django.db import connection
from django.test import TestCase
from rest_framework.test import APIRequestFactory, APITestCase
from rich import print

from ami.main.models import (
Expand All @@ -17,6 +18,7 @@
TaxonRank,
group_images_into_events,
)
from ami.users.models import User


def setup_test_project(reuse=True) -> tuple[Project, Deployment]:
Expand Down Expand Up @@ -547,3 +549,51 @@ def test_taxon_detail(self):
response = self.client.get(f"/api/v2/taxa/{taxon.pk}/")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["name"], taxon.name)


class TestIdentification(APITestCase):
def setUp(self) -> None:
project, deployment = setup_test_project()
create_taxa(project=project)
create_captures(deployment=deployment)
group_images_into_events(deployment=deployment)
create_occurrences(deployment=deployment, num=5)
self.project = project
self.user = User.objects.create_user( # type: ignore
email="[email protected]",
is_staff=True,
)
self.factory = APIRequestFactory()
self.client.force_authenticate(user=self.user)
return super().setUp()

def test_identification(self):
from ami.main.models import Identification, Taxon

"""
Post a new identification suggestion and check that it changed the occurrence's determination.
"""

suggest_id_endpoint = "/api/v2/identifications/"
taxa = Taxon.objects.filter(projects=self.project)
assert taxa.count() > 1

occurrence = Occurrence.objects.filter(project=self.project).exclude(determination=None)[0]
original_taxon = occurrence.determination
assert original_taxon is not None
new_taxon = Taxon.objects.exclude(pk=original_taxon.pk)[0]
comment = "Test identification comment"

response = self.client.post(
suggest_id_endpoint,
{
"occurrence_id": occurrence.pk,
"taxon_id": new_taxon.pk,
"comment": comment,
},
)
self.assertEqual(response.status_code, 201)
occurrence.refresh_from_db()
self.assertEqual(occurrence.determination, new_taxon)
identification = Identification.objects.get(pk=response.json()["id"])
self.assertEqual(identification.comment, comment)
4 changes: 2 additions & 2 deletions config/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@
"TITLE": "Automated Monitoring of Insects ML Platform API",
"DESCRIPTION": "Documentation of API endpoints of Automated Monitoring of Insects ML Platform",
"VERSION": "1.0.0",
"SERVE_PERMISSIONS": ["rest_framework.permissions.IsAdminUser"],
# "SERVE_PERMISSIONS": ["rest_framework.permissions.AllowAny"],
# "SERVE_PERMISSIONS": ["rest_framework.permissions.IsAdminUser"],
"SERVE_PERMISSIONS": ["rest_framework.permissions.AllowAny"],
}
# Your stuff...
# ------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion config/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# API base url
path("api/v2/", include("config.api_router", namespace="api")),
# OpenAPI Docs
path("api/v2/schema/", SpectacularAPIView.as_view(), name="api-schema"),
path("api/v2/schema/", SpectacularAPIView.as_view(api_version="api"), name="api-schema"),
path(
"api/v2/docs/",
SpectacularSwaggerView.as_view(url_name="api-schema"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const config: FormConfig = {
label: translate(STRING.FIELD_LABEL_NAME),
},
image: {
label: translate(STRING.FIELD_LABEL_IMAGE),
label: translate(STRING.FIELD_LABEL_ICON),
description: [
translate(STRING.MESSAGE_IMAGE_SIZE, {
value: bytesToMB(API_MAX_UPLOAD_SIZE),
Expand Down
22 changes: 22 additions & 0 deletions ui/src/data-services/hooks/collections/usePopulateCollection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { useMutation, useQueryClient } from '@tanstack/react-query'
import axios from 'axios'
import { API_ROUTES, API_URL } from 'data-services/constants'
import { getAuthHeader } from 'data-services/utils'
import { useUser } from 'utils/user/userContext'

export const usePopulateCollection = () => {
const { user } = useUser()
const queryClient = useQueryClient()

const { mutateAsync, isLoading, isSuccess, error } = useMutation({
mutationFn: (id: string) =>
axios.post<{ id: number }>(`${API_URL}/${API_ROUTES.COLLECTIONS}/${id}/populate/`, undefined, {
headers: getAuthHeader(user),
}),
onSuccess: () => {
queryClient.invalidateQueries([API_ROUTES.COLLECTIONS])
},
})

return { populateCollection: mutateAsync, isLoading, isSuccess, error }
}
Loading

0 comments on commit 006b2d3

Please sign in to comment.