Skip to content

Commit

Permalink
Merge pull request #330 from RolnickLab/feat/precalculate-values
Browse files Browse the repository at this point in the history
Cache and filter counts & scores. Improve determination calculation.
  • Loading branch information
mihow authored Dec 3, 2023
2 parents fd83bc3 + 687380d commit c32b948
Show file tree
Hide file tree
Showing 24 changed files with 377 additions and 179 deletions.
10 changes: 10 additions & 0 deletions ami/base/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from django.db.models import F, OrderBy
from rest_framework.filters import OrderingFilter


class NullsLastOrderingFilter(OrderingFilter):
    """Ordering filter that always sorts NULL values after non-NULL ones.

    Django's default ordering places NULLs first for descending order on
    some backends; this wraps each requested field in an ``OrderBy``
    expression with ``nulls_last=True`` so rows lacking a value never
    float to the top of a listing.
    """

    def get_ordering(self, request, queryset, view):
        ordering = super().get_ordering(request, queryset, view)
        if not ordering:
            # Preserve the parent's falsy return (None or empty) unchanged.
            return ordering
        expressions = []
        for field in ordering:
            descending = field.startswith("-")
            field_name = field.lstrip("-")
            expressions.append(
                OrderBy(F(field_name), descending=descending, nulls_last=True)
            )
        return expressions
4 changes: 4 additions & 0 deletions ami/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,9 @@ def save_async(self, *args, **kwargs):
"""Save the model in a background task."""
ami.tasks.model_task.delay(self.__class__.__name__, self.pk, "save", *args, **kwargs)

def update_calculated_fields(self, *args, **kwargs):
    """Recompute and store this model's cached/denormalized values.

    Default implementation is a deliberate no-op; concrete models
    override this hook with their own aggregate calculations.
    """

class Meta:
abstract = True
2 changes: 1 addition & 1 deletion ami/jobs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def enqueue(self):
self.finished_at = None
self.scheduled_at = datetime.datetime.now()
self.status = run_job.AsyncResult(task_id).status
self.save()
self.save(force_update=True)

def setup(self, save=True):
"""
Expand Down
2 changes: 2 additions & 0 deletions ami/jobs/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class Meta:
"details",
"name",
"delay",
"limit",
"shuffle",
"project",
"project_id",
"deployment",
Expand Down
5 changes: 5 additions & 0 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from rest_framework.decorators import action
from rest_framework.response import Response

Expand All @@ -7,6 +9,8 @@
from .models import Job
from .serializers import JobListSerializer, JobSerializer

logger = logging.getLogger(__name__)


class JobViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -89,6 +93,7 @@ def perform_create(self, serializer):
"""
If the ``start_now`` parameter is passed, enqueue the job immediately.
"""

job: Job = serializer.save() # type: ignore
if url_boolean_param(self.request, "start_now", default=False):
# job.run()
Expand Down
15 changes: 6 additions & 9 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from django.db.models import Count, QuerySet
from django.db.models import QuerySet
from rest_framework import serializers

from ami.base.serializers import DefaultSerializer, get_current_user, reverse_with_params
Expand Down Expand Up @@ -90,10 +90,6 @@ class DeploymentListSerializer(DefaultSerializer):

class Meta:
model = Deployment
queryset = Deployment.objects.annotate(
events_count=Count("events"),
occurrences_count=Count("occurrences"),
)
fields = [
"id",
"name",
Expand Down Expand Up @@ -379,6 +375,7 @@ class Meta:
"occurrences",
"occurrence_images",
"last_detected",
"best_determination_score",
"created_at",
"updated_at",
]
Expand Down Expand Up @@ -908,15 +905,12 @@ def get_determination_details(self, obj: Occurrence):
taxon=taxon,
identification=identification,
prediction=prediction,
score=obj.determination_score(),
score=obj.determination_score,
)


class OccurrenceSerializer(OccurrenceListSerializer):
determination = CaptureTaxonSerializer(read_only=True)
determination_id = serializers.PrimaryKeyRelatedField(
write_only=True, queryset=Taxon.objects.all(), source="determination"
)
detections = DetectionNestedSerializer(many=True, read_only=True)
identifications = OccurrenceIdentificationSerializer(many=True, read_only=True)
predictions = OccurrenceClassificationSerializer(many=True, read_only=True)
Expand All @@ -932,6 +926,9 @@ class Meta:
"identifications",
"predictions",
]
read_only_fields = [
"determination_score",
]


class EventCaptureNestedSerializer(DefaultSerializer):
Expand Down
53 changes: 33 additions & 20 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from rest_framework import exceptions as api_exceptions
from rest_framework import permissions, viewsets
from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.filters import SearchFilter
from rest_framework.response import Response
from rest_framework.views import APIView

from ami import tasks
from ami.base.filters import NullsLastOrderingFilter

from ..models import (
DEFAULT_CONFIDENCE_THRESHOLD,
Classification,
Deployment,
Detection,
Expand Down Expand Up @@ -78,7 +80,7 @@
class DefaultViewSetMixin:
filter_backends = [
DjangoFilterBackend,
OrderingFilter,
NullsLastOrderingFilter,
SearchFilter,
]
filterset_fields = []
Expand Down Expand Up @@ -119,16 +121,7 @@ class DeploymentViewSet(DefaultViewSet):
for the list and detail views.
"""

queryset = Deployment.objects.annotate(
events_count=models.Count("events", distinct=True),
occurrences_count=models.Count("occurrences", distinct=True),
taxa_count=models.Count("occurrences__determination", distinct=True),
captures_count=models.Count("events__captures", distinct=True),
# The first and last date should come from the captures,
# but it may be much slower to query.
first_date=models.Min("events__start__date"),
last_date=models.Max("events__end__date"),
).select_related("project")
queryset = Deployment.objects.select_related("project")
filterset_fields = ["project"]
ordering_fields = [
"created_at",
Expand Down Expand Up @@ -447,6 +440,7 @@ class OccurrenceViewSet(DefaultViewSet):
"event",
)
.prefetch_related("detections")
.order_by("-determination_score")
.all()
)
serializer_class = OccurrenceSerializer
Expand All @@ -459,6 +453,7 @@ class OccurrenceViewSet(DefaultViewSet):
"duration",
"deployment",
"determination",
"determination_score",
"event",
"detections_count",
]
Expand Down Expand Up @@ -495,6 +490,7 @@ class TaxonViewSet(DefaultViewSet):
"occurrences_count",
"detections_count",
"last_detected",
"best_determination_score",
"name",
]
search_fields = ["name", "parent__name"]
Expand Down Expand Up @@ -557,18 +553,29 @@ def filter_by_occurrence(self, queryset: QuerySet) -> QuerySet:

if occurrence_id:
occurrence = Occurrence.objects.get(id=occurrence_id)
# This query does not need the same filtering as the others
return queryset.filter(occurrences=occurrence).distinct()
elif project_id:
project = Project.objects.get(id=project_id)
return super().get_queryset().filter(occurrences__project=project).distinct()
queryset = super().get_queryset().filter(occurrences__project=project)
elif deployment_id:
deployment = Deployment.objects.get(id=deployment_id)
return super().get_queryset().filter(occurrences__deployment=deployment).distinct()
queryset = super().get_queryset().filter(occurrences__deployment=deployment)
elif event_id:
event = Event.objects.get(id=event_id)
return super().get_queryset().filter(occurrences__event=event).distinct()
else:
return queryset
queryset = super().get_queryset().filter(occurrences__event=event)

queryset = (
queryset.annotate(best_determination_score=models.Max("occurrences__determination_score"))
.filter(best_determination_score__gte=DEFAULT_CONFIDENCE_THRESHOLD)
.distinct()
)

# If ordering is not specified, order by best determination score
if not self.request.query_params.get("ordering"):
queryset = queryset.order_by("-best_determination_score")

return queryset

def get_queryset(self) -> QuerySet:
qs = super().get_queryset()
Expand Down Expand Up @@ -621,10 +628,16 @@ def get(self, request):
"events_count": Event.objects.filter(deployment__project=project).count(),
"captures_count": SourceImage.objects.filter(deployment__project=project).count(),
"detections_count": Detection.objects.filter(occurrence__project=project).count(),
"occurrences_count": Occurrence.objects.filter(project=project).count(),
"occurrences_count": Occurrence.objects.filter(
project=project,
determination_score__gte=DEFAULT_CONFIDENCE_THRESHOLD,
).count(),
"taxa_count": Taxon.objects.annotate(occurrences_count=models.Count("occurrences"))
.filter(occurrences_count__gt=0)
.filter(occurrences__project=project)
.filter(
occurrences_count__gt=0,
occurrences__determination_score__gte=DEFAULT_CONFIDENCE_THRESHOLD,
occurrences__project=project,
)
.distinct()
.count(),
}
Expand Down
47 changes: 47 additions & 0 deletions ami/main/migrations/0024_deployment_captures_count_and_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Generated by Django 4.2.2 on 2023-12-01 21:42

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add nullable cached/denormalized aggregate columns to Deployment."""

    dependencies = [
        ("main", "0023_taxon_main_taxon_orderin_4ffb7b_idx"),
    ]

    # Every column is nullable: NULL means "not yet calculated". The values
    # are populated by Deployment.save()/update_calculated_fields and
    # backfilled for existing rows by the data migration 0025.
    operations = (
        [
            migrations.AddField(
                model_name="deployment",
                name=column,
                field=models.IntegerField(blank=True, null=True),
            )
            for column in ("captures_count", "detections_count", "events_count")
        ]
        + [
            migrations.AddField(
                model_name="deployment",
                name=column,
                field=models.DateTimeField(blank=True, null=True),
            )
            for column in ("first_capture_timestamp", "last_capture_timestamp")
        ]
        + [
            migrations.AddField(
                model_name="deployment",
                name=column,
                field=models.IntegerField(blank=True, null=True),
            )
            for column in ("occurrences_count", "taxa_count")
        ]
    )
27 changes: 27 additions & 0 deletions ami/main/migrations/0025_update_deployment_aggregates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Generated by Django 4.2.2 on 2023-12-01 21:43

from django.db import migrations
import logging

logger = logging.getLogger(__name__)


# Save all Deployment objects to update their calculated fields.
def update_deployment_aggregates(apps, schema_editor):
    """Backfill the cached aggregate columns added to Deployment in 0024.

    NOTE(review): this deliberately imports the *live* model rather than the
    historical one from ``apps.get_model`` — presumably because the
    ``update_calculated_fields`` save hook only exists on the real model.
    Confirm this remains valid if ``Deployment.save()`` changes in later
    migrations, since data migrations using live models can break over time.
    """
    # Deployment = apps.get_model("main", "Deployment")
    from ami.main.models import Deployment

    for deployment in Deployment.objects.all():
        logger.info(f"Updating deployment {deployment}")
        deployment.save(update_calculated_fields=True)


class Migration(migrations.Migration):
    """Data migration: backfill the Deployment aggregate columns added in 0024."""

    dependencies = [
        ("main", "0024_deployment_captures_count_and_more"),
    ]

    # Reverse is a no-op: un-applying simply leaves the cached values in
    # place; the columns themselves are removed by reversing 0024.
    operations = [
        migrations.RunPython(update_deployment_aggregates, migrations.RunPython.noop),
    ]
17 changes: 17 additions & 0 deletions ami/main/migrations/0026_occurrence_determination_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 4.2.2 on 2023-12-02 01:08

from django.db import migrations, models


class Migration(migrations.Migration):
    # Schema migration: adds a cached ``determination_score`` column to
    # Occurrence. Nullable because existing rows start without a value;
    # they are backfilled by the data migration 0027 that follows.
    dependencies = [
        ("main", "0025_update_deployment_aggregates"),
    ]

    operations = [
        migrations.AddField(
            model_name="occurrence",
            name="determination_score",
            field=models.FloatField(blank=True, null=True),
        ),
    ]
22 changes: 22 additions & 0 deletions ami/main/migrations/0027_update_occurrence_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Generated by Django 4.2.2 on 2023-12-02 01:08

from django.db import migrations


# Call save on all occurrences to update their scores
def update_occurrence_scores(apps, schema_editor):
    """Backfill ``Occurrence.determination_score`` (column added in 0026).

    Re-saves every occurrence so the score is computed and stored.

    NOTE(review): imports the *live* model instead of the historical
    ``apps.get_model`` version — presumably because the score calculation
    lives in the real model's ``save()``; confirm this stays valid as the
    model evolves.
    """
    from ami.main.models import Occurrence

    # Stream rows in chunks with iterator() instead of materializing and
    # caching the entire table in memory, which .all() iteration would do.
    for occurrence in Occurrence.objects.iterator():
        occurrence.save()


class Migration(migrations.Migration):
    # Data migration: re-save every Occurrence so the new
    # ``determination_score`` column (added in 0026) is populated.
    dependencies = [
        ("main", "0026_occurrence_determination_score"),
    ]

    # Reverse is a no-op: the backfilled scores need no undo; the column
    # itself is dropped by reversing 0026.
    operations = [
        migrations.RunPython(update_occurrence_scores, migrations.RunPython.noop),
    ]
Loading

0 comments on commit c32b948

Please sign in to comment.