diff --git a/ami/base/filters.py b/ami/base/filters.py new file mode 100644 index 000000000..40a030b59 --- /dev/null +++ b/ami/base/filters.py @@ -0,0 +1,10 @@ +from django.db.models import F, OrderBy +from rest_framework.filters import OrderingFilter + + +class NullsLastOrderingFilter(OrderingFilter): + def get_ordering(self, request, queryset, view): + values = super().get_ordering(request, queryset, view) + if not values: + return values + return [OrderBy(F(value.lstrip("-")), descending=value.startswith("-"), nulls_last=True) for value in values] diff --git a/ami/base/models.py b/ami/base/models.py index 25fd4d640..32da1245e 100644 --- a/ami/base/models.py +++ b/ami/base/models.py @@ -21,5 +21,9 @@ def save_async(self, *args, **kwargs): """Save the model in a background task.""" ami.tasks.model_task.delay(self.__class__.__name__, self.pk, "save", *args, **kwargs) + def update_calculated_fields(self, *args, **kwargs): + """Update calculated fields specific to each model.""" + pass + class Meta: abstract = True diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 67703cab4..1f6eba684 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -304,7 +304,7 @@ def enqueue(self): self.finished_at = None self.scheduled_at = datetime.datetime.now() self.status = run_job.AsyncResult(task_id).status - self.save() + self.save(force_update=True) def setup(self, save=True): """ diff --git a/ami/jobs/serializers.py b/ami/jobs/serializers.py index 576f8f9d8..473fd31e7 100644 --- a/ami/jobs/serializers.py +++ b/ami/jobs/serializers.py @@ -84,6 +84,8 @@ class Meta: "details", "name", "delay", + "limit", + "shuffle", "project", "project_id", "deployment", diff --git a/ami/jobs/views.py b/ami/jobs/views.py index 9e0d3c163..ca8bb1595 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -1,3 +1,5 @@ +import logging + from rest_framework.decorators import action from rest_framework.response import Response @@ -7,6 +9,8 @@ from .models import Job from .serializers import JobListSerializer, JobSerializer +logger = logging.getLogger(__name__) + class JobViewSet(DefaultViewSet): """ @@ -89,6 +93,7 @@ def perform_create(self, serializer): """ If the ``start_now`` parameter is passed, enqueue the job immediately. 
""" + job: Job = serializer.save() # type: ignore if url_boolean_param(self.request, "start_now", default=False): # job.run() diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index 284b73fc5..1e39ad9f3 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -1,6 +1,6 @@ import datetime -from django.db.models import Count, QuerySet +from django.db.models import QuerySet from rest_framework import serializers from ami.base.serializers import DefaultSerializer, get_current_user, reverse_with_params @@ -90,10 +90,6 @@ class DeploymentListSerializer(DefaultSerializer): class Meta: model = Deployment - queryset = Deployment.objects.annotate( - events_count=Count("events"), - occurrences_count=Count("occurrences"), - ) fields = [ "id", "name", @@ -379,6 +375,7 @@ class Meta: "occurrences", "occurrence_images", "last_detected", + "best_determination_score", "created_at", "updated_at", ] @@ -908,15 +905,12 @@ def get_determination_details(self, obj: Occurrence): taxon=taxon, identification=identification, prediction=prediction, - score=obj.determination_score(), + score=obj.determination_score, ) class OccurrenceSerializer(OccurrenceListSerializer): determination = CaptureTaxonSerializer(read_only=True) - determination_id = serializers.PrimaryKeyRelatedField( - write_only=True, queryset=Taxon.objects.all(), source="determination" - ) detections = DetectionNestedSerializer(many=True, read_only=True) identifications = OccurrenceIdentificationSerializer(many=True, read_only=True) predictions = OccurrenceClassificationSerializer(many=True, read_only=True) @@ -932,6 +926,9 @@ class Meta: "identifications", "predictions", ] + read_only_fields = [ + "determination_score", + ] class EventCaptureNestedSerializer(DefaultSerializer): diff --git a/ami/main/api/views.py b/ami/main/api/views.py index 00b07abea..4ad501084 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -11,13 +11,15 @@ from rest_framework import exceptions as api_exceptions from rest_framework import permissions, viewsets from rest_framework.decorators import action -from rest_framework.filters import OrderingFilter, SearchFilter +from rest_framework.filters import SearchFilter from rest_framework.response import Response from rest_framework.views import APIView from ami import tasks +from ami.base.filters import NullsLastOrderingFilter from ..models import ( + DEFAULT_CONFIDENCE_THRESHOLD, Classification, Deployment, Detection, @@ -78,7 +80,7 @@ class DefaultViewSetMixin: filter_backends = [ DjangoFilterBackend, - OrderingFilter, + NullsLastOrderingFilter, SearchFilter, ] filterset_fields = [] @@ -119,16 +121,7 @@ class DeploymentViewSet(DefaultViewSet): for the list and detail views. """ - queryset = Deployment.objects.annotate( - events_count=models.Count("events", distinct=True), - occurrences_count=models.Count("occurrences", distinct=True), - taxa_count=models.Count("occurrences__determination", distinct=True), - captures_count=models.Count("events__captures", distinct=True), - # The first and last date should come from the captures, - # but it may be much slower to query. 
- first_date=models.Min("events__start__date"), - last_date=models.Max("events__end__date"), - ).select_related("project") + queryset = Deployment.objects.select_related("project") filterset_fields = ["project"] ordering_fields = [ "created_at", @@ -447,6 +440,7 @@ class OccurrenceViewSet(DefaultViewSet): "event", ) .prefetch_related("detections") + .order_by("-determination_score") .all() ) serializer_class = OccurrenceSerializer @@ -459,6 +453,7 @@ class OccurrenceViewSet(DefaultViewSet): "duration", "deployment", "determination", + "determination_score", "event", "detections_count", ] @@ -495,6 +490,7 @@ class TaxonViewSet(DefaultViewSet): "occurrences_count", "detections_count", "last_detected", + "best_determination_score", "name", ] search_fields = ["name", "parent__name"] @@ -557,18 +553,29 @@ def filter_by_occurrence(self, queryset: QuerySet) -> QuerySet: if occurrence_id: occurrence = Occurrence.objects.get(id=occurrence_id) + # This query does not need the same filtering as the others return queryset.filter(occurrences=occurrence).distinct() elif project_id: project = Project.objects.get(id=project_id) - return super().get_queryset().filter(occurrences__project=project).distinct() + queryset = super().get_queryset().filter(occurrences__project=project) elif deployment_id: deployment = Deployment.objects.get(id=deployment_id) - return super().get_queryset().filter(occurrences__deployment=deployment).distinct() + queryset = super().get_queryset().filter(occurrences__deployment=deployment) elif event_id: event = Event.objects.get(id=event_id) - return super().get_queryset().filter(occurrences__event=event).distinct() - else: - return queryset + queryset = super().get_queryset().filter(occurrences__event=event) + + queryset = ( + queryset.annotate(best_determination_score=models.Max("occurrences__determination_score")) + .filter(best_determination_score__gte=DEFAULT_CONFIDENCE_THRESHOLD) + .distinct() + ) + + # If ordering is not specified, order by best determination score + if not self.request.query_params.get("ordering"): + queryset = queryset.order_by("-best_determination_score") + + return queryset def get_queryset(self) -> QuerySet: qs = super().get_queryset() @@ -621,10 +628,16 @@ def get(self, request): "events_count": Event.objects.filter(deployment__project=project).count(), "captures_count": SourceImage.objects.filter(deployment__project=project).count(), "detections_count": Detection.objects.filter(occurrence__project=project).count(), - "occurrences_count": Occurrence.objects.filter(project=project).count(), + "occurrences_count": Occurrence.objects.filter( + project=project, + determination_score__gte=DEFAULT_CONFIDENCE_THRESHOLD, + ).count(), "taxa_count": Taxon.objects.annotate(occurrences_count=models.Count("occurrences")) - .filter(occurrences_count__gt=0) - .filter(occurrences__project=project) + .filter( + occurrences_count__gt=0, + occurrences__determination_score__gte=DEFAULT_CONFIDENCE_THRESHOLD, + occurrences__project=project, + ) .distinct() .count(), } diff --git a/ami/main/migrations/0024_deployment_captures_count_and_more.py b/ami/main/migrations/0024_deployment_captures_count_and_more.py new file mode 100644 index 000000000..1e4457228 --- /dev/null +++ b/ami/main/migrations/0024_deployment_captures_count_and_more.py @@ -0,0 +1,47 @@ +# Generated by Django 4.2.2 on 2023-12-01 21:42 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0023_taxon_main_taxon_orderin_4ffb7b_idx"), + ] + + 
operations = [ + migrations.AddField( + model_name="deployment", + name="captures_count", + field=models.IntegerField(blank=True, null=True), + ), + migrations.AddField( + model_name="deployment", + name="detections_count", + field=models.IntegerField(blank=True, null=True), + ), + migrations.AddField( + model_name="deployment", + name="events_count", + field=models.IntegerField(blank=True, null=True), + ), + migrations.AddField( + model_name="deployment", + name="first_capture_timestamp", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name="deployment", + name="last_capture_timestamp", + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name="deployment", + name="occurrences_count", + field=models.IntegerField(blank=True, null=True), + ), + migrations.AddField( + model_name="deployment", + name="taxa_count", + field=models.IntegerField(blank=True, null=True), + ), + ] diff --git a/ami/main/migrations/0025_update_deployment_aggregates.py b/ami/main/migrations/0025_update_deployment_aggregates.py new file mode 100644 index 000000000..4055be21e --- /dev/null +++ b/ami/main/migrations/0025_update_deployment_aggregates.py @@ -0,0 +1,27 @@ +# Generated by Django 4.2.2 on 2023-12-01 21:43 + +from django.db import migrations +import logging + +logger = logging.getLogger(__name__) + + +# Save all Deployment objects to update their calculated fields. +def update_deployment_aggregates(apps, schema_editor): + # Deployment = apps.get_model("main", "Deployment") + from ami.main.models import Deployment + + for deployment in Deployment.objects.all(): + logger.info(f"Updating deployment {deployment}") + deployment.save(update_calculated_fields=True) + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0024_deployment_captures_count_and_more"), + ] + + # operations = [] + operations = [ + migrations.RunPython(update_deployment_aggregates, migrations.RunPython.noop), + ] diff --git a/ami/main/migrations/0026_occurrence_determination_score.py b/ami/main/migrations/0026_occurrence_determination_score.py new file mode 100644 index 000000000..038e77a65 --- /dev/null +++ b/ami/main/migrations/0026_occurrence_determination_score.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.2 on 2023-12-02 01:08 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0025_update_deployment_aggregates"), + ] + + operations = [ + migrations.AddField( + model_name="occurrence", + name="determination_score", + field=models.FloatField(blank=True, null=True), + ), + ] diff --git a/ami/main/migrations/0027_update_occurrence_scores.py b/ami/main/migrations/0027_update_occurrence_scores.py new file mode 100644 index 000000000..59c2ee5d4 --- /dev/null +++ b/ami/main/migrations/0027_update_occurrence_scores.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2.2 on 2023-12-02 01:08 + +from django.db import migrations + + +# Call save on all occurrences to update their scores +def update_occurrence_scores(apps, schema_editor): + # Occurrence = apps.get_model("main", "Occurrence") + from ami.main.models import Occurrence + + for occurrence in Occurrence.objects.all(): + occurrence.save() + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0026_occurrence_determination_score"), + ] + + operations = [ + migrations.RunPython(update_occurrence_scores, migrations.RunPython.noop), + ] diff --git a/ami/main/models.py b/ami/main/models.py index f4549db8a..5448e1219 100644 
--- a/ami/main/models.py
+++ b/ami/main/models.py
@@ -25,6 +25,8 @@
 from ami.users.models import User
 from ami.utils.schemas import OrderedEnum
 
+logger = logging.getLogger(__name__)
+
 # Constants
 _POST_TITLE_MAX_LENGTH: Final = 80
 
@@ -52,15 +54,29 @@ class TaxonRank(OrderedEnum):
     ]
 )
 
+DEFAULT_CONFIDENCE_THRESHOLD = 0.29
+
 # @TODO move to settings & make configurable
 _SOURCE_IMAGES_URL_BASE = "https://static.dev.insectai.org/ami-trapdata/vermont/snapshots/"
 _CROPS_URL_BASE = "https://static.dev.insectai.org/ami-trapdata/crops"
 
-as_choices = lambda x: [(i, i) for i in x]  # noqa: E731
 
+def get_media_url(path: str) -> str:
+    """
+    If path is a full URL, return it as-is.
+    Otherwise, join it with the MEDIA_URL setting.
+    """
+    # @TODO use settings
+    # urllib.parse.urljoin(settings.MEDIA_URL, self.path)
+    if path.startswith("http"):
+        url = path
+    else:
+        url = urllib.parse.urljoin(_CROPS_URL_BASE, path.lstrip("/"))
+    return url
 
-logger = logging.getLogger(__name__)
+
+as_choices = lambda x: [(i, i) for i in x]  # noqa: E731
 
 
 @final
@@ -165,15 +181,8 @@ class DeploymentManager(models.Manager):
     def get_queryset(self):
         return (
-            super()
-            .get_queryset()
-            .annotate(
-                events_count=models.Count("events"),
-                # These are very slow as the numbers increase (1M captures)
-                # occurrences_count=models.Count("occurrences"),
-                # captures_count=models.Count("captures"),
-                # detections_count=models.Count("captures__detections")
-            )
+            super().get_queryset()
+            # Add any common annotations or optimizations here
         )
 
 
@@ -248,11 +257,18 @@ class Deployment(BaseModel):
     name = models.CharField(max_length=_POST_TITLE_MAX_LENGTH)
     description = models.TextField(blank=True)
 
+    latitude = models.FloatField(null=True, blank=True)
+    longitude = models.FloatField(null=True, blank=True)
+    image = models.ImageField(upload_to="deployments", blank=True, null=True)
+
+    project = models.ForeignKey(Project, on_delete=models.SET_NULL, null=True, related_name="deployments")
+
     # @TODO consider sharing only the "data source auth/config" then a one-to-one config for each deployment
     data_source = models.ForeignKey(
         "S3StorageSource", on_delete=models.SET_NULL, null=True, blank=True, related_name="deployments"
     )
+
+    # Precalculated values from the data source
     data_source_total_files = models.IntegerField(blank=True, null=True)
     data_source_total_size = models.BigIntegerField(blank=True, null=True)
     data_source_subdir = models.CharField(max_length=255, blank=True, null=True)
@@ -264,11 +280,14 @@ class Deployment(BaseModel):
     # data_source_last_check_status = models.CharField(max_length=255, blank=True, null=True)
     # data_source_last_check_notes = models.TextField(max_length=255, blank=True, null=True)
 
-    latitude = models.FloatField(null=True, blank=True)
-    longitude = models.FloatField(null=True, blank=True)
-    image = models.ImageField(upload_to="deployments", blank=True, null=True)
-
-    project = models.ForeignKey(Project, on_delete=models.SET_NULL, null=True, related_name="deployments")
+    # Precalculated values
+    events_count = models.IntegerField(blank=True, null=True)
+    occurrences_count = models.IntegerField(blank=True, null=True)
+    captures_count = models.IntegerField(blank=True, null=True)
+    detections_count = models.IntegerField(blank=True, null=True)
+    taxa_count = models.IntegerField(blank=True, null=True)
+    first_capture_timestamp = models.DateTimeField(blank=True, null=True)
+    last_capture_timestamp = models.DateTimeField(blank=True, null=True)
 
     research_site = models.ForeignKey(
         Site,
@@ -288,29 +307,9 @@
class Meta: ordering = ["name"] - def events_count(self) -> int | None: - # return self.events.count() - # Uses the annotated value from the custom manager - return None - - def captures_count(self) -> int: - return self.data_source_total_files or 0 - - def detections_count(self) -> int: - return Detection.objects.filter(Q(source_image__deployment=self)).count() - # return None - - def occurrences_count(self) -> int: - return self.occurrences.count() - # return None - def taxa(self) -> models.QuerySet["Taxon"]: return Taxon.objects.filter(Q(occurrences__deployment=self)).distinct() - def taxa_count(self) -> int | None: - return self.taxa().count() - # return None - def example_captures(self, num=10) -> models.QuerySet["SourceImage"]: return SourceImage.objects.filter(deployment=self).order_by("-size")[:num] @@ -321,10 +320,9 @@ def first_capture(self) -> typing.Optional["SourceImage"]: return SourceImage.objects.filter(deployment=self).order_by("timestamp").first() def last_capture(self) -> typing.Optional["SourceImage"]: - return SourceImage.objects.filter(deployment=self).order_by("-timestamp").first() + return SourceImage.objects.filter(deployment=self).order_by("timestamp").last() - @functools.cached_property - def first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.datetime]: + def get_first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.datetime]: # Retrieve the timestamps of the first and last capture in a single query first, last = ( SourceImage.objects.filter(deployment=self) @@ -334,14 +332,10 @@ def first_and_last_timestamps(self) -> tuple[datetime.datetime, datetime.datetim return (first, last) def first_date(self) -> datetime.date | None: - date, _ = self.first_and_last_timestamps - if date: - return date.date() + return self.first_capture_timestamp.date() if self.first_capture_timestamp else None def last_date(self) -> datetime.date | None: - _, date = self.first_and_last_timestamps - if date: - return date.date() + return self.last_capture_timestamp.date() if self.last_capture_timestamp else None def data_source_uri(self) -> str | None: if self.data_source: @@ -408,7 +402,7 @@ def update_children(self): ] for model_name in child_models: model = apps.get_model("main", model_name) - project_values = model.objects.filter(deployment=self).values_list("project", flat=True).distinct() + project_values = set(model.objects.filter(deployment=self).values_list("project", flat=True).distinct()) if len(project_values) > 1: logger.warning( f"Deployment {self} has alternate projects set on {model_name} " @@ -422,19 +416,27 @@ def update_calculated_fields(self, save=False): self.data_source_total_files = self.captures.count() self.data_source_total_size = self.captures.aggregate(total_size=models.Sum("size")).get("total_size") + self.events_count = self.events.count() + self.captures_count = self.data_source_total_files or self.captures.count() + self.detections_count = Detection.objects.filter(Q(source_image__deployment=self)).count() + self.occurrences_count = self.occurrences.count() + self.taxa_count = Taxon.objects.filter(Q(occurrences__deployment=self)).distinct().count() + + self.first_capture_timestamp, self.last_capture_timestamp = self.get_first_and_last_timestamps() + if save: - self.save() + self.save(update_calculated_fields=False) - def save(self, *args, update_calculated_fields=True, **kwargs): + def save(self, update_calculated_fields=True, *args, **kwargs): + super().save(*args, **kwargs) if self.pk and update_calculated_fields: - 
self.update_calculated_fields() + self.update_calculated_fields(save=True) if self.project: self.update_children() # @TODO this isn't working as a background task # ami.tasks.model_task.delay("Project", self.project.pk, "update_children_project") # @TODO Use "dirty" flag strategy to only update when needed ami.tasks.regroup_events.delay(self.pk) - super().save(*args, **kwargs) @final @@ -566,7 +568,7 @@ def summary_data(self): return plots - def update_calculated_fields(self): + def update_calculated_fields(self, save=False): if not self.group_by and self.start: # If no group_by is set, use the start "day" self.group_by = self.start.date() @@ -583,9 +585,13 @@ def update_calculated_fields(self): if last: self.end = last["timestamp"] - def save(self, *args, **kwargs): - self.update_calculated_fields() + if save: + self.save(update_calculated_fields=False) + + def save(self, update_calculated_fields=True, *args, **kwargs): super().save(*args, **kwargs) + if update_calculated_fields: + self.update_calculated_fields(save=True) def group_images_into_events( @@ -935,7 +941,7 @@ def get_dimensions(self) -> tuple[int | None, int | None]: return self.width, self.height return None, None - def update_calculated_fields(self): + def update_calculated_fields(self, save=False): if self.path and not self.timestamp: self.timestamp = self.extract_timestamp() if self.path and not self.public_base_url: @@ -944,10 +950,13 @@ def update_calculated_fields(self): self.project = self.deployment.project if self.pk is not None: self.detections_count = self.get_detections_count() + if save: + self.save(update_calculated_fields=False) - def save(self, *args, **kwargs): - self.update_calculated_fields() + def save(self, update_calculated_fields=True, *args, **kwargs): super().save(*args, **kwargs) + if update_calculated_fields: + self.update_calculated_fields(save=True) class Meta: ordering = ("deployment", "event", "timestamp") @@ -1185,6 +1194,7 @@ class Identification(BaseModel): blank=True, related_name="agreed_identifications", ) + score = 1.0 # Always 1 for humans, at this time class Meta: ordering = [ @@ -1395,39 +1405,24 @@ def best_classification(self): else: return (None, None) - def url(self): - # @TODO use settings - # urllib.parse.urljoin(settings.MEDIA_URL, self.path) - logger.info(f"DETECTION URL: {self.path}") - print(f"DETECTION URL: {self.path}") - if self.path.startswith("http"): - url = self.path - else: - url = urllib.parse.urljoin(_CROPS_URL_BASE, self.path.lstrip("/")) - logger.info(f"DETECTION URL: {url}") - return url + def url(self) -> str | None: + return get_media_url(self.path) if self.path else None - def associate_new_occurrence(self): + def associate_new_occurrence(self) -> "Occurrence": """ Create and associate a new occurrence with this detection. 
""" if self.occurrence: return self.occurrence - classifications = self.classifications.first() - if classifications: - taxon = classifications.taxon - else: - taxon = None - occurrence = Occurrence( + occurrence = Occurrence.objects.create( event=self.source_image.event, deployment=self.source_image.deployment, project=self.source_image.project, - determination=taxon, ) - occurrence.save() self.occurrence = occurrence self.save() + occurrence.save() # Need to save again to update the aggregate values # Update aggregate values on source image # @TODO this should be done async in a task with an eta of a few seconds # so it isn't done for every detection in a batch @@ -1449,6 +1444,7 @@ class Occurrence(BaseModel): # @TODO change Determination to a nested field with a Taxon, User, Identification, etc like the serializer # this could be a OneToOneField to a Determination model or a JSONField validated by a Pydantic model determination = models.ForeignKey("Taxon", on_delete=models.SET_NULL, null=True, related_name="occurrences") + determination_score = models.FloatField(null=True, blank=True) event = models.ForeignKey(Event, on_delete=models.SET_NULL, null=True, related_name="occurrences") deployment = models.ForeignKey(Deployment, on_delete=models.SET_NULL, null=True, related_name="occurrences") @@ -1481,7 +1477,7 @@ def first_appearance(self) -> SourceImage | None: @functools.cached_property def last_appearance(self) -> SourceImage | None: # @TODO it appears we only need the last timestamp, that could be an annotated value - last = self.detections.order_by("-timestamp").select_related("source_image").first() + last = self.detections.order_by("timestamp").select_related("source_image").last() if last: return last.source_image @@ -1524,43 +1520,29 @@ def best_prediction(self): def best_identification(self): return Identification.objects.filter(occurrence=self, withdrawn=False).order_by("-created_at").first() - def get_determination_score(self) -> float: - logger.warning( - f"Calculating determination score for Occurrence #{self.pk} " - "(this should be come from a query annotation and be cached)" - ) + def get_determination_score(self) -> float | None: if not self.determination: - return 0 + return None elif self.best_identification: - # If the occurrence has been verified by humans, then consider determination 100% certain - return 1.0 + return self.best_identification.score + elif self.best_prediction: + return self.best_prediction.score else: - return Classification.objects.filter(detection__occurrence=self).aggregate(models.Max("score"))[ - "score__max" - ] - - def determination_score(self) -> float: - """ - Example, get best determination score for each occurrence if it has no identifications: - - If score was populated by a query annotation, use that - otherwise call the get() method to calculate it. 
- """ - if hasattr(self, "determination_score") and isinstance(self.determination_score, float): - score = self.determination_score - else: - score = self.get_determination_score() - return score + return None def predictions(self): # Retrieve the classification with the max score for each algorithm - classifications = Classification.objects.filter(detection__occurrence=self).filter( - score__in=models.Subquery( - Classification.objects.filter(detection__occurrence=self) - .values("algorithm") - .annotate(max_score=models.Max("score")) - .values("max_score") + classifications = ( + Classification.objects.filter(detection__occurrence=self) + .filter( + score__in=models.Subquery( + Classification.objects.filter(detection__occurrence=self) + .values("algorithm") + .annotate(max_score=models.Max("score")) + .values("max_score") + ) ) + .order_by("-created_at") ) return classifications @@ -1576,8 +1558,22 @@ def url(self): # @TODO this was a temporary hack. Use settings and reverse(). return f"https://app.preview.insectai.org/occurrences/{self.pk}" + def save(self, update_determination=True, *args, **kwargs): + super().save(*args, **kwargs) + if update_determination: + update_occurrence_determination( + self, + current_determination=self.determination, + save=True, + ) -def update_occurrence_determination(occurrence: Occurrence, current_determination: typing.Optional["Taxon"] = None): + class Meta: + ordering = ["-determination_score"] + + +def update_occurrence_determination( + occurrence: Occurrence, current_determination: typing.Optional["Taxon"] = None, save=True +): """ Update the determination of the occurrence based on the identifications & predictions. @@ -1591,6 +1587,8 @@ def update_occurrence_determination(occurrence: Occurrence, current_determinatio @TODO Add tests for this important method! """ + needs_update = False + current_determination = ( current_determination or Occurrence.objects.select_related("determination") @@ -1598,19 +1596,30 @@ def update_occurrence_determination(occurrence: Occurrence, current_determinatio .get(pk=occurrence.pk)["determination"] ) new_determination = None + new_score = None top_identification = occurrence.best_identification if top_identification and top_identification.taxon and top_identification.taxon != current_determination: new_determination = top_identification.taxon + new_score = top_identification.score elif not top_identification: top_prediction = occurrence.best_prediction if top_prediction and top_prediction.taxon and top_prediction.taxon != current_determination: new_determination = top_prediction.taxon + new_score = top_prediction.score if new_determination and new_determination != current_determination: - logger.info(f"Changing determination of {occurrence} from {current_determination} to {new_determination}") + logger.info(f"Changing det. of {occurrence} from {current_determination} to {new_determination}") occurrence.determination = new_determination - occurrence.save() + needs_update = True + + if new_score and new_score != occurrence.determination_score: + logger.info(f"Changing det. 
score of {occurrence} from {occurrence.determination_score} to {new_score}") + occurrence.determination_score = new_score + needs_update = True + + if save and needs_update: + occurrence.save(update_determination=False) @final @@ -1818,17 +1827,28 @@ def last_detected(self) -> datetime.datetime | None: # This is handled by an annotation return None - def occurrence_images(self): + def best_determination_score(self) -> float | None: + # This is handled by an annotation if we are filtering by project, deployment or event + return None + + def occurrence_images(self, limit: int | None = 10) -> list[str]: """ Return one image from each occurrence of this Taxon. The image should be from the detection with the highest classification score. + + This is used for image thumbnail previews in the species summary view. """ - # @TODO Can we use a single query - for occurrence in self.occurrences.prefetch_related("detections__classifications").all(): - detection = occurrence.detections.order_by("-classifications__score").first() - if detection: - yield detection.url() + # Retrieve the URLs using a single optimized query + detection_image_paths = ( + self.occurrences.prefetch_related("detections__classifications") + .annotate(max_score=models.Max("detections__classifications__score")) + .filter(detections__classifications__score=models.F("max_score")) + .order_by("-max_score") + .values_list("detections__path", flat=True)[:limit] + ) + + return [get_media_url(path) for path in detection_image_paths if path] def list_names(self) -> str: return ", ".join(self.lists.values_list("name", flat=True)) diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index 64712f245..ff9beb08e 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -200,7 +200,7 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m new_classification.detection = detection new_classification.taxon = taxon new_classification.algorithm = algo - new_classification.score = classification.scores[0] + new_classification.score = max(classification.scores) new_classification.timestamp = now() # @TODO get timestamp from API response # @TODO add reference to job or pipeline? 
@@ -214,9 +214,11 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
                 deployment=source_image.deployment,
                 project=source_image.project,
                 determination=taxon,
+                determination_score=new_classification.score,
             )
             detection.occurrence = occurrence
             detection.save()
+            detection.occurrence.save()
 
     # Update precalculated counts on source images
     for source_image in source_images:
diff --git a/ami/utils/dates.py b/ami/utils/dates.py
index 49146ebee..e89d358bc 100644
--- a/ami/utils/dates.py
+++ b/ami/utils/dates.py
@@ -36,8 +36,13 @@ def get_image_timestamp_from_filename(img_path, raise_error=False) -> datetime.d
     # Extract date from a filename using regex in the format %Y%m%d%H%M%S
+    date = None
     matches = re.search(r"(\d{14})", name)
     if matches:
-        date = datetime.datetime.strptime(matches.group(), "%Y%m%d%H%M%S")
-    else:
+        try:
+            date = datetime.datetime.strptime(matches.group(), "%Y%m%d%H%M%S")
+        except ValueError:
+            pass
+
+    if not date:
         try:
             date = dateutil.parser.parse(name, fuzzy=False)  # Fuzzy will interpret "DSC_1974" as 1974-01-01
         except dateutil.parser.ParserError:
diff --git a/ui/src/data-services/models/occurrence-details.ts b/ui/src/data-services/models/occurrence-details.ts
index 503c68e47..fcdd0cd5f 100644
--- a/ui/src/data-services/models/occurrence-details.ts
+++ b/ui/src/data-services/models/occurrence-details.ts
@@ -100,8 +100,15 @@ export class OccurrenceDetails extends Occurrence {
     )
     const classification = detection?.classifications?.[0]
+    if (!detection) {
+      return
+    }
+    let label = 'No classification'
 
-    if (!classification) {
-      return
+    if (classification) {
+      label = `${classification.taxon.name} (${_.round(
+        classification.score,
+        4
+      )})`
     }
 
     return {
@@ -116,10 +123,7 @@ export class OccurrenceDetails extends Occurrence {
         width: detection.width,
         height: detection.height,
       },
-      label: `${classification.taxon.name} (${_.round(
-        classification.score,
-        4
-      )})`,
+      label,
       timeLabel: getFormatedTimeString({
         date: new Date(detection.timestamp),
       }),
diff --git a/ui/src/data-services/models/species.ts b/ui/src/data-services/models/species.ts
index f2c6c135c..0346839ab 100644
--- a/ui/src/data-services/models/species.ts
+++ b/ui/src/data-services/models/species.ts
@@ -22,11 +22,11 @@ export class Species extends Taxon {
   }
 
-  get numDetections(): number {
-    return this._species.detections_count
+  get numDetections(): number | null {
+    return this._species.detections_count || null
   }
 
-  get numOccurrences(): number {
-    return this._species.occurrences_count
+  get numOccurrences(): number | null {
+    return this._species.occurrences_count || null
   }
 
   get trainingImagesLabel(): string {
@@ -36,4 +36,8 @@ export class Species extends Taxon {
   get trainingImagesUrl(): string {
     return `https://www.gbif.org/occurrence/gallery?advanced=1&verbatim_scientific_name=${this.name}`
   }
+
+  get score(): number | null {
+    return this._species.best_determination_score
+  }
 }
diff --git a/ui/src/design-system/components/info-block/info-block.tsx b/ui/src/design-system/components/info-block/info-block.tsx
index 7f39db162..65f98da0e 100644
--- a/ui/src/design-system/components/info-block/info-block.tsx
+++ b/ui/src/design-system/components/info-block/info-block.tsx
@@ -1,5 +1,6 @@
 import classNames from 'classnames'
 import { Link } from 'react-router-dom'
+import { STRING, translate } from 'utils/language'
 import styles from './info-block.module.scss'
 
 interface Field {
@@ -11,7 +12,10 @@ export const InfoBlock = ({ fields }: { fields: Field[] }) => (
   <>
     {fields.map((field, index) => {
-      const value = field.value !== undefined ? field.value : 'N/A'
+      const value =
+        field.value !== undefined
+          ? field.value
+          : translate(STRING.VALUE_NOT_AVAILABLE)
       return (

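The `ami/utils/dates.py` change above makes the `dateutil` fallback reachable both when no 14-digit run is present and when one matches but is not a valid datetime (e.g. `99999999999999`), which is why `date` must start as `None`. A runnable sketch of the fixed control flow, using a standalone `parse_image_timestamp` name rather than the project's real `get_image_timestamp_from_filename`:

```python
import datetime
import re

import dateutil.parser


def parse_image_timestamp(name: str) -> datetime.datetime | None:
    date = None
    # First try a strict %Y%m%d%H%M%S match on any 14-digit run
    matches = re.search(r"(\d{14})", name)
    if matches:
        try:
            date = datetime.datetime.strptime(matches.group(), "%Y%m%d%H%M%S")
        except ValueError:
            pass
    # Fall back to a general (non-fuzzy) parse of the whole name
    if not date:
        try:
            date = dateutil.parser.parse(name, fuzzy=False)
        except dateutil.parser.ParserError:
            pass
    return date


assert parse_image_timestamp("20230601221000-snapshot.jpg") == datetime.datetime(2023, 6, 1, 22, 10)
assert parse_image_timestamp("IMG_banana.jpg") is None  # nothing parsable anywhere
```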
diff --git a/ui/src/pages/occurrences/occurrence-columns.tsx b/ui/src/pages/occurrences/occurrence-columns.tsx index 09556c30a..4c0ba9963 100644 --- a/ui/src/pages/occurrences/occurrence-columns.tsx +++ b/ui/src/pages/occurrences/occurrence-columns.tsx @@ -9,7 +9,7 @@ import { ImageTableCell } from 'design-system/components/table/image-table-cell/ import { CellTheme, ImageCellTheme, - TableColumn, + TableColumn } from 'design-system/components/table/types' import { Tooltip } from 'design-system/components/tooltip/tooltip' import { Agree } from 'pages/occurrence-details/agree/agree' @@ -81,14 +81,22 @@ export const columns: (projectId: string) => TableColumn[] = ( { id: 'date', name: translate(STRING.FIELD_LABEL_DATE), - sortField: 'event__start', - renderCell: (item: Occurrence) => , + sortField: 'first_appearance_time', + renderCell: (item: Occurrence) => ( + + + + ), }, { id: 'time', sortField: 'first_appearance_time', name: translate(STRING.FIELD_LABEL_TIME), - renderCell: (item: Occurrence) => , + renderCell: (item: Occurrence) => ( + + + + ) }, { id: 'duration', @@ -106,6 +114,15 @@ export const columns: (projectId: string) => TableColumn[] = ( ), }, + { + id: 'score', + name: translate(STRING.FIELD_LABEL_BEST_SCORE), + sortField: 'determination_score', + renderCell: (item: Occurrence) => ( + // This should always appear as a float with 2 decimal places, even if 1.00 + + ), + }, ] const TaxonCell = ({ diff --git a/ui/src/pages/occurrences/occurrences.tsx b/ui/src/pages/occurrences/occurrences.tsx index 6b8e66453..1c268e395 100644 --- a/ui/src/pages/occurrences/occurrences.tsx +++ b/ui/src/pages/occurrences/occurrences.tsx @@ -30,10 +30,11 @@ export const Occurrences = () => { }>({ snapshots: true, id: true, - deployment: true, - session: true, + date: true, + time: true, duration: true, detections: true, + score: true, }) const [sort, setSort] = useState() const { pagination, setPage } = usePagination() diff --git a/ui/src/pages/session-details/playback/capture-job/process-now.tsx b/ui/src/pages/session-details/playback/capture-job/process-now.tsx index ba471ebf7..c79a31d42 100644 --- a/ui/src/pages/session-details/playback/capture-job/process-now.tsx +++ b/ui/src/pages/session-details/playback/capture-job/process-now.tsx @@ -39,7 +39,7 @@ export const ProcessNow = ({ theme={ButtonTheme.Neutral} onClick={() => { createJob({ - delay: 1, + delay: 0, name: `Capture #${capture.id}`, sourceImage: capture.id, pipeline: pipelineId, diff --git a/ui/src/pages/session-details/playback/playback-controls/pipelines-picker.tsx b/ui/src/pages/session-details/playback/playback-controls/pipelines-picker.tsx index 66f02c238..9b1d5854f 100644 --- a/ui/src/pages/session-details/playback/playback-controls/pipelines-picker.tsx +++ b/ui/src/pages/session-details/playback/playback-controls/pipelines-picker.tsx @@ -21,7 +21,7 @@ export const PipelinesPicker = ({ value: p.id, label: p.name, }))} - placeholder="Pick a pipeline" + placeholder="Pipeline" showClear={false} theme={SelectTheme.NeutralCompact} value={value} diff --git a/ui/src/pages/species-details/species-details.tsx b/ui/src/pages/species-details/species-details.tsx index 939164aef..f753a0c37 100644 --- a/ui/src/pages/species-details/species-details.tsx +++ b/ui/src/pages/species-details/species-details.tsx @@ -53,7 +53,7 @@ export const SpeciesDetails = ({ species }: { species: Species }) => { value: species.trainingImagesLabel, to: species.trainingImagesUrl, }, - ] + ].filter((field) => field.value !== null); return (

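Backend note: the `save()` overrides added above for `Deployment`, `Event`, and `SourceImage` all follow the same pattern — persist first, then recompute denormalized fields and save once more with the recursion guard disabled. A toy, framework-free sketch of that control flow (`_persist` stands in for Django's actual `Model.save()`):

```python
class DenormalizedModel:
    """Minimal stand-in for a model with precalculated (denormalized) fields."""

    def __init__(self) -> None:
        self.saves = 0
        self.captures_count: int | None = None

    def _persist(self) -> None:
        self.saves += 1  # stands in for the real database write

    def update_calculated_fields(self, save: bool = False) -> None:
        self.captures_count = 42  # stands in for the real aggregate queries
        if save:
            # Re-save WITHOUT recalculating, so save() and
            # update_calculated_fields() never call each other forever.
            self.save(update_calculated_fields=False)

    def save(self, update_calculated_fields: bool = True) -> None:
        self._persist()
        if update_calculated_fields:
            self.update_calculated_fields(save=True)


m = DenormalizedModel()
m.save()
assert m.saves == 2 and m.captures_count == 42  # one write, one follow-up write
```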
diff --git a/ui/src/pages/species/species-columns.tsx b/ui/src/pages/species/species-columns.tsx index da1ce6ed6..fe6e103ca 100644 --- a/ui/src/pages/species/species-columns.tsx +++ b/ui/src/pages/species/species-columns.tsx @@ -48,17 +48,6 @@ export const columns: (projectId: string) => TableColumn[] = ( ), }, - { - id: 'detections', - sortField: 'detections_count', - name: translate(STRING.FIELD_LABEL_DETECTIONS), - styles: { - textAlign: TextAlign.Right, - }, - renderCell: (item: Species) => ( - - ), - }, { id: 'occurrences', sortField: 'occurrences_count', @@ -77,6 +66,17 @@ export const columns: (projectId: string) => TableColumn[] = ( ), }, + { + id: 'score', + sortField: 'best_determination_score', + name: translate(STRING.FIELD_LABEL_BEST_SCORE), + styles: { + textAlign: TextAlign.Right, + }, + renderCell: (item: Species) => ( + + ), + }, { id: 'training-images', name: translate(STRING.FIELD_LABEL_TRAINING_IMAGES), diff --git a/ui/src/utils/language.ts b/ui/src/utils/language.ts index 5b96e0a38..0dbba7b18 100644 --- a/ui/src/utils/language.ts +++ b/ui/src/utils/language.ts @@ -46,6 +46,7 @@ export enum STRING { /* FIELD_LABEL */ FIELD_LABEL_AVG_TEMP, + FIELD_LABEL_BEST_SCORE, FIELD_LABEL_CAPTURES, FIELD_LABEL_COMMENT, FIELD_LABEL_CONNECTION_STATUS, @@ -163,6 +164,7 @@ export enum STRING { UPDATING_DATA, USER_INFO, VERIFIED_BY, + VALUE_NOT_AVAILABLE, } const ENGLISH_STRINGS: { [key in STRING]: string } = { @@ -200,6 +202,7 @@ const ENGLISH_STRINGS: { [key in STRING]: string } = { /* FIELD_LABEL */ [STRING.FIELD_LABEL_AVG_TEMP]: 'Avg temp', + [STRING.FIELD_LABEL_BEST_SCORE]: 'Score', [STRING.FIELD_LABEL_CAPTURES]: 'Captures', [STRING.FIELD_LABEL_COMMENT]: 'Comment', [STRING.FIELD_LABEL_CONNECTION_STATUS]: 'Connection status', @@ -240,7 +243,7 @@ const ENGLISH_STRINGS: { [key in STRING]: string } = { [STRING.FIELD_LABEL_THUMBNAIL]: 'Thumbnail', [STRING.FIELD_LABEL_TIME]: 'Time', [STRING.FIELD_LABEL_TIMESTAMP]: 'Timestamp', - [STRING.FIELD_LABEL_TRAINING_IMAGES]: 'Training images', + [STRING.FIELD_LABEL_TRAINING_IMAGES]: 'Reference images', [STRING.FIELD_LABEL_FIRST_DATE]: 'First date', [STRING.FIELD_LABEL_LAST_DATE]: 'Last date', [STRING.FIELD_LABEL_UPDATED_AT]: 'Updated at', @@ -340,6 +343,7 @@ const ENGLISH_STRINGS: { [key in STRING]: string } = { [STRING.UPDATING_DATA]: 'Updating data', [STRING.USER_INFO]: 'User info', [STRING.VERIFIED_BY]: 'Verified by\n{{name}}', + [STRING.VALUE_NOT_AVAILABLE]: 'n/a', } // When we have more translations available, this function could return a value based on current language settings.
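Finally, the `NullsLastOrderingFilter` introduced at the top of this diff is what keeps the new nullable `determination_score` and `best_determination_score` columns sortable without `NULL` rows floating to the top. A minimal sketch of the transformation it applies to each `?ordering=` value (Django import only; no database or settings are required to construct the expressions, and the standalone `nulls_last_ordering` helper is illustrative):

```python
from django.db.models import F, OrderBy


def nulls_last_ordering(values: list[str]) -> list[OrderBy]:
    """Rewrite DRF ordering values so NULLs always sort last."""
    return [
        OrderBy(F(value.lstrip("-")), descending=value.startswith("-"), nulls_last=True)
        for value in values
    ]


clauses = nulls_last_ordering(["-determination_score", "name"])
assert clauses[0].descending and clauses[0].nulls_last  # DESC, NULLs last
assert not clauses[1].descending and clauses[1].nulls_last  # ASC, NULLs last
```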