Skip to content

Commit

Permalink
feat: continue splitting Taxa and TaxaObserved models
Browse files Browse the repository at this point in the history
  • Loading branch information
mihow committed Sep 4, 2024
1 parent d82fb73 commit 6c65430
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 42 deletions.
3 changes: 3 additions & 0 deletions ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,12 +458,15 @@ class Meta:


class TaxonListSerializer(DefaultSerializer):
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")

class Meta:
model = Taxon
fields = [
"id",
"name",
"rank",
"parents",
"details",
"created_at",
"updated_at",
Expand Down
6 changes: 6 additions & 0 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,7 @@ def get_serializer_class(self):
else:
return TaxonSerializer

# @TODO this can now be removed since we are using TaxonObservedViewSet
def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
"""
Filter taxa by when/where it has occurred.
Expand Down Expand Up @@ -857,6 +858,7 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
# @TODO need to return the models.Q filter used, so we can use it for counts and related occurrences.
return queryset, filter_active

# @TODO this can now be removed since we are using TaxonObservedViewSet
def filter_by_classification_threshold(self, queryset: QuerySet) -> QuerySet:
"""
Filter taxa by their best determination score in occurrences.
Expand All @@ -876,6 +878,7 @@ def filter_by_classification_threshold(self, queryset: QuerySet) -> QuerySet:

return queryset

# @TODO this can now be removed since we are using TaxonObservedViewSet
def get_occurrences_filters(self, queryset: QuerySet) -> tuple[QuerySet, models.Q]:
# @TODO this should check what the user has access to
project_id = self.request.query_params.get("project")
Expand All @@ -901,6 +904,7 @@ def get_occurrences_filters(self, queryset: QuerySet) -> tuple[QuerySet, models.

return taxon_occurrences_query, taxon_occurrences_count_filter

# @TODO this can now be removed since we are using TaxonObservedViewSet
def add_occurrence_counts(self, queryset: QuerySet, occurrences_count_filter: models.Q) -> QuerySet:
qs = queryset.annotate(
occurrences_count=models.Count(
Expand All @@ -912,10 +916,12 @@ def add_occurrence_counts(self, queryset: QuerySet, occurrences_count_filter: mo
)
return qs

# @TODO this can now be removed since we are using TaxonObservedViewSet
def add_filtered_occurrences(self, queryset: QuerySet, occurrences_query: QuerySet) -> QuerySet:
    """
    Prefetch the ``occurrences`` relation of ``queryset`` using ``occurrences_query``.

    The supplied occurrences queryset replaces the default one for the prefetch,
    so only the occurrences it selects are loaded onto each row.
    """
    filtered_prefetch = Prefetch("occurrences", queryset=occurrences_query)
    return queryset.prefetch_related(filtered_prefetch)

# @TODO this can now be removed since we are using TaxonObservedViewSet
def zero_occurrences(self, queryset: QuerySet) -> QuerySet:
"""
Return a queryset with zero occurrences but compatible with the original queryset.
Expand Down
138 changes: 102 additions & 36 deletions ami/taxa/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,71 @@
import typing as t

from django.conf import settings
from django.contrib.postgres.aggregates import ArrayAgg
from django.contrib.postgres.fields import ArrayField
from django.db import models, transaction
from django.utils import timezone

from ami.base.models import BaseModel, update_calculated_fields_in_bulk
from ami.main.models import Classification, Detection, Occurrence, Project, Taxon
from ami.utils.storages import get_temporary_media_url

# from ami.utils.storages import get_temporary_media_url

logger = logging.getLogger(__name__)


class TaxonObservedQuerySet(models.QuerySet):
    """QuerySet for ``TaxonObserved`` with helpers that attach occurrence data to each row."""

    def with_occurrence_images(self, limit: int = 10, classification_threshold: float | None = None):
        """
        Annotate each row with ``occurrence_images``, an array of detection image paths.

        Detections are matched to the row through the occurrence's determination
        (``taxon``) and ``project``, and must have a classification scoring at
        least ``classification_threshold``.

        :param limit: maximum number of detection paths selected by the subquery.
        :param classification_threshold: minimum classification score. ``None``
            means "use ``settings.DEFAULT_CONFIDENCE_THRESHOLD``"; an explicit
            ``0`` is honoured and disables the score cut-off.
        """
        # Compare against None rather than relying on truthiness: callers pass 0
        # to mean "no threshold", and `0 or default` would silently apply the default.
        if classification_threshold is None:
            classification_threshold = settings.DEFAULT_CONFIDENCE_THRESHOLD

        # Subquery to get the top N detections for each taxon and project
        top_detections = (
            Detection.objects.filter(
                occurrence__determination=models.OuterRef("taxon"),
                occurrence__project=models.OuterRef("project"),
                classifications__score__gte=classification_threshold,
            )
            .order_by("-classifications__score")
            .values("path")
            .distinct()[:limit]
        )

        # NOTE(review): the aggregation below is applied to an already sliced,
        # ordered and distinct queryset; confirm the generated SQL really
        # aggregates the top ``limit`` paths on the target database.
        # Annotate the main queryset with the array of image paths
        return self.annotate(
            occurrence_images=models.Subquery(
                top_detections.values("occurrence__determination")
                .annotate(paths=ArrayAgg("path"))
                .values("paths")[:1],
                output_field=ArrayField(models.CharField(max_length=255)),
            )
        )

    def with_occurrences(self, limit: int = 10):
        """
        Prefetch the most recently created occurrences into ``top_occurrences``.

        :param limit: size of the slice applied to the prefetch queryset.
        """
        return self.prefetch_related(
            models.Prefetch(
                "occurrences",
                queryset=Occurrence.objects.order_by("-created_at")[:limit],
                to_attr="top_occurrences",
            )
        )
        # NOTE(review): disabled draft of a richer occurrences query (score
        # threshold + event filters, first/last appearance timestamps), kept for reference:
        # taxon_occurrences_query = (
        #     Occurrence.objects.filter(
        #         determination_score__gte=get_active_classification_threshold(self.request),
        #         event__isnull=False,
        #     )
        #     .distinct()
        #     .annotate(
        #         first_appearance_timestamp=models.Min("detections__timestamp"),
        #         last_appearance_timestamp=models.Max("detections__timestamp"),
        #     )
        #     .order_by("-first_appearance_timestamp")


class TaxonObservedManager(models.Manager):
    """Manager whose base queryset always joins the related ``taxon`` and ``project``."""

    def get_queryset(self) -> TaxonObservedQuerySet:
        """Build a ``TaxonObservedQuerySet`` on this manager's database with both relations selected."""
        base_qs = TaxonObservedQuerySet(self.model, using=self._db)
        return base_qs.select_related("taxon", "project")


@t.final
class TaxonObserved(BaseModel):
"""
Expand All @@ -35,6 +90,8 @@ class TaxonObserved(BaseModel):
last_detected = models.DateTimeField(null=True, blank=True)
calculated_fields_updated_at = models.DateTimeField(blank=True, null=True)

objects = TaxonObservedManager.from_queryset(TaxonObservedQuerySet)()

class Meta:
ordering = ["-last_detected"]
verbose_name_plural = "Taxa Observed"
Expand Down Expand Up @@ -79,47 +136,56 @@ def get_last_detected(self) -> datetime.datetime | None:
.first()
)

def occurrence_images(
self,
limit: int | None = 10,
project_id: int | None = None,
classification_threshold: float | None = None,
) -> list[str]:
def occurrence_images(self) -> list[str]:
"""
Return one image from each occurrence of this Taxon.
The image should be from the detection with the highest classification score.
This is used for image thumbnail previews in the species summary view.
The project ID is an optional filter however
@TODO important, this should always filter by what the current user has access to.
Use the request.user to filter by the user's access.
Use the request to generate the full media URLs.
"""

classification_threshold = classification_threshold or settings.DEFAULT_CONFIDENCE_THRESHOLD

# Retrieve the URLs using a single optimized query
qs = (
self.occurrences.prefetch_related(
models.Prefetch(
"detections__classifications",
queryset=Classification.objects.filter(score__gte=classification_threshold).order_by("-score"),
)
)
.annotate(max_score=models.Max("detections__classifications__score"))
.filter(detections__classifications__score=models.F("max_score"))
.order_by("-max_score")
)
if project_id is not None:
# @TODO this should check the user's access instead
qs = qs.filter(project=project_id)

detection_image_paths = qs.values_list("detections__path", flat=True)[:limit]

# @TODO should this be done in the serializer?
# @TODO better way to get distinct values from an annotated queryset?
return [get_temporary_media_url(path) for path in detection_image_paths if path]
return []

# def occurrence_images(
# self,
# limit: int | None = 10,
# project_id: int | None = None,
# classification_threshold: float | None = None,
# ) -> list[str]:
# """
# Return one image from each occurrence of this Taxon.
# The image should be from the detection with the highest classification score.

# This is used for image thumbnail previews in the species summary view.

# The project ID is an optional filter however
# @TODO important, this should always filter by what the current user has access to.
# Use the request.user to filter by the user's access.
# Use the request to generate the full media URLs.
# """

# classification_threshold = classification_threshold or settings.DEFAULT_CONFIDENCE_THRESHOLD

# # Retrieve the URLs using a single optimized query
# qs = (
# self.occurrences.prefetch_related(
# models.Prefetch(
# "detections__classifications",
# queryset=Classification.objects.filter(score__gte=classification_threshold).order_by("-score"),
# )
# )
# .annotate(max_score=models.Max("detections__classifications__score"))
# .filter(detections__classifications__score=models.F("max_score"))
# .order_by("-max_score")
# )
# if project_id is not None:
# # @TODO this should check the user's access instead
# qs = qs.filter(project=project_id)

# detection_image_paths = qs.values_list("detections__path", flat=True)[:limit]

# # @TODO should this be done in the serializer?
# # @TODO better way to get distinct values from an annotated queryset?
# return [get_temporary_media_url(path) for path in detection_image_paths if path]

def update_calculated_fields(self, save=True, updated_timestamp: datetime.datetime | None = None):
"""
Expand Down
24 changes: 21 additions & 3 deletions ami/taxa/serializers.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
from ami.base.serializers import DefaultSerializer
from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer
from ami.main.api.serializers import OccurrenceNestedSerializer, TaxonSerializer
from ami.main.models import Detection
from ami.taxa.models import TaxonObserved

MinimalDetectionNestedSerializer = MinimalNestedModelSerializer.create_for_model(Detection)


class TaxonObservedListSerializer(DefaultSerializer):
# occurrences = DefaultSerializer(many=True, read_only=True, source="top_occurrences")
# best_detection = MinimalNestedModelSerializer(source="best_detection_id", read_only=True)
taxon = TaxonSerializer()

class TaxonObservedSerializer(DefaultSerializer):
class Meta:
model = TaxonObserved
fields = [
"id",
"details",
"taxon",
"project",
"detections_count",
"occurrences_count",
"best_detection",
"best_determination_score",
"last_detected",
"created_at",
"updated_at",
"occurrence_images",
]


class TaxonObservedSerializer(TaxonObservedListSerializer):
    """Detail serializer: every field of the list serializer plus nested occurrences."""

    # Reads from the ``top_occurrences`` attribute rather than the full relation.
    occurrences = OccurrenceNestedSerializer(source="top_occurrences", many=True, read_only=True)

    class Meta(TaxonObservedListSerializer.Meta):
        # Build a new list so the parent's ``fields`` is left untouched.
        fields = [*TaxonObservedListSerializer.Meta.fields, "occurrences"]
30 changes: 27 additions & 3 deletions ami/taxa/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from ami.main.api.views import DefaultViewSet
from ami.taxa.models import TaxonObserved
from ami.taxa.serializers import TaxonObservedSerializer
from ami.taxa.serializers import TaxonObservedListSerializer, TaxonObservedSerializer

logger = logging.getLogger(__name__)

Expand All @@ -12,5 +12,29 @@ class TaxonObservedViewSet(DefaultViewSet):
Endpoint for taxa information that have been observed in a project.
"""

queryset = TaxonObserved.objects.all()
serializer_class = TaxonObservedSerializer
ordering_fields = [
"id",
"taxon__name",
"detections_count",
"occurrences_count",
"best_determination_score",
"last_detected",
"created_at",
"updated_at",
]

queryset = TaxonObserved.objects.all().select_related("taxon", "project")

def get_queryset(self):
    """
    Return the queryset for the current action.

    - ``list``: rows annotated with occurrence images.
    - ``retrieve``: rows with prefetched occurrences plus occurrence images.
    - any other action: the unmodified base queryset.

    Both annotated variants pass ``classification_threshold=0``.
    """
    qs = super().get_queryset()
    if self.action == "list":
        return qs.with_occurrence_images(classification_threshold=0)
    if self.action == "retrieve":
        return qs.with_occurrences().with_occurrence_images(classification_threshold=0)
    # Previously this fell through and returned None, which breaks every other
    # action that needs a queryset (e.g. update or destroy lookups).
    return qs

def get_serializer_class(self):
    """Use the lighter list serializer for the ``list`` action and the full serializer otherwise."""
    per_action = {"list": TaxonObservedListSerializer}
    return per_action.get(self.action, TaxonObservedSerializer)

# Set plural name for the viewset list name

0 comments on commit 6c65430

Please sign in to comment.