Fix incorrect dates of occurrences & detections (#360)
* Fix Detection timestamps that do not use the date they were "captured"

* Allow sorting by species / determination name

* It appears we don't need first & last detection objects, speed up query

* Use timestamps from query annotation instead of join, sort by timestamps

* Add some methods for troubleshooting bad event/session groupings
mihow authored Feb 22, 2024
1 parent a355c8c commit d37c2cb
Showing 10 changed files with 192 additions and 42 deletions.
11 changes: 10 additions & 1 deletion ami/main/admin.py
@@ -103,7 +103,16 @@ def sync_captures(self, request: HttpRequest, queryset: QuerySet[Deployment]) ->
msg = f"Syncing captures for {len(queued_tasks)} deployments in background: {queued_tasks}"
self.message_user(request, msg)

actions = [sync_captures]
# Action that regroups all captures in the deployment into events
@admin.action(description="Regroup captures into events")
def regroup_events(self, request: HttpRequest, queryset: QuerySet[Deployment]) -> None:
from ami.main.models import group_images_into_events

for deployment in queryset:
group_images_into_events(deployment)
self.message_user(request, f"Regrouped {queryset.count()} deployments.")

actions = [sync_captures, regroup_events]

def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
qs = super().get_queryset(request)
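For ad-hoc troubleshooting, the same regrouping can also be run from a Django shell (python manage.py shell). A minimal sketch, assuming Deployment is importable from ami.main.models like the helper function, and with a placeholder pk:

from ami.main.models import Deployment, group_images_into_events

deployment = Deployment.objects.get(pk=1)  # pk=1 is a placeholder
group_images_into_events(deployment)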
23 changes: 14 additions & 9 deletions ami/main/api/serializers.py
@@ -603,8 +603,8 @@ class TaxonOccurrenceNestedSerializer(DefaultSerializer):
event = EventNestedSerializer(read_only=True)
best_detection = TaxonDetectionsSerializer(read_only=True)
determination = CaptureTaxonSerializer(read_only=True)
first_appearance = TaxonSourceImageNestedSerializer(read_only=True)
last_appearance = TaxonSourceImageNestedSerializer(read_only=True)
# first_appearance = TaxonSourceImageNestedSerializer(read_only=True)
# last_appearance = TaxonSourceImageNestedSerializer(read_only=True)

class Meta:
model = Occurrence
@@ -619,8 +619,10 @@ class Meta:
"detections_count",
"duration",
"duration_label",
"first_appearance",
"last_appearance",
"first_appearance_timestamp",
"last_appearance_timestamp",
# "first_appearance",
# "last_appearance",
]


@@ -944,21 +946,24 @@ class OccurrenceListSerializer(DefaultSerializer):
determination = CaptureTaxonSerializer(read_only=True)
deployment = DeploymentNestedSerializer(read_only=True)
event = EventNestedSerializer(read_only=True)
first_appearance = TaxonSourceImageNestedSerializer(read_only=True)
# first_appearance = TaxonSourceImageNestedSerializer(read_only=True)
determination_details = serializers.SerializerMethodField()

class Meta:
model = Occurrence
# queryset = Occurrence.objects.annotate(
# determination_score=Max("detections__classsifications__score")
# determination_score=Max("detections__classifications__score")
# )
fields = [
"id",
"details",
"event",
"deployment",
"first_appearance",
# So far, we don't need the whole related object, just the timestamps
# "first_appearance",
"first_appearance_timestamp",
# need both timestamp and time for sorting at the database level
# (want to see all moths that occur after 3am, regardless of the date)
"first_appearance_time",
"duration",
"duration_label",
@@ -1005,7 +1010,7 @@ class OccurrenceSerializer(OccurrenceListSerializer):
predictions = OccurrenceClassificationSerializer(many=True, read_only=True)
deployment = DeploymentNestedSerializer(read_only=True)
event = EventNestedSerializer(read_only=True)
first_appearance = TaxonSourceImageNestedSerializer(read_only=True)
# first_appearance = TaxonSourceImageNestedSerializer(read_only=True)

class Meta:
model = Occurrence
@@ -1101,7 +1106,7 @@ def get_captures(self, obj):

def get_capture_page_offset(self, obj) -> int | None:
"""
Look up the source image (capture) that contains a specfic detection or occurrence.
Look up the source image (capture) that contains a specific detection or occurrence.
Return the page offset for the capture to be used when requesting the capture list endpoint.
"""
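Both of the new fields have a purpose: first_appearance_timestamp sorts occurrences chronologically, while first_appearance_time supports time-of-day queries at the database level. A minimal sketch of the "all moths after 3am, regardless of the date" case from the comment above, assuming the queryset is annotated the same way as in the view:

import datetime

from django.db import models

from ami.main.models import Occurrence

# Occurrences first seen after 3am, on any night
after_3am = Occurrence.objects.annotate(
    first_appearance_time=models.Min("detections__timestamp__time"),
).filter(first_appearance_time__gte=datetime.time(3))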
36 changes: 24 additions & 12 deletions ami/main/api/views.py
@@ -461,7 +461,7 @@ class OccurrenceViewSet(DefaultViewSet):
queryset = Occurrence.objects.all()

serializer_class = OccurrenceSerializer
filterset_fields = ["event", "deployment", "determination", "project"]
filterset_fields = ["event", "deployment", "determination", "project", "determination__rank"]
ordering_fields = [
"created_at",
"updated_at",
@@ -471,6 +471,7 @@ class OccurrenceViewSet(DefaultViewSet):
"duration",
"deployment",
"determination",
"determination__name",
"determination_score",
"event",
"detections_count",
@@ -492,24 +493,27 @@ def get_queryset(self) -> QuerySet:
"determination",
"deployment",
"event",
).annotate(
detections_count=models.Count("detections", distinct=True),
duration=models.Max("detections__timestamp") - models.Min("detections__timestamp"),
first_appearance_timestamp=models.Min("detections__timestamp"),
first_appearance_time=models.Min("detections__timestamp__time"),
)
if self.action == "list":
qs = (
qs.all()
.exclude(detections=None)
.exclude(event=None)
.filter(determination_score__gte=get_active_classification_threshold(self.request))
.annotate(
detections_count=models.Count("detections", distinct=True),
duration=models.Max("detections__timestamp") - models.Min("detections__timestamp"),
first_appearance_timestamp=models.Min("detections__timestamp"),
first_appearance_time=models.Min("detections__timestamp__time"),
)
.exclude(first_appearance_timestamp=None) # This must come after annotations
.order_by("-determination_score")
)
else:
qs = qs.prefetch_related("detections", "detections__source_image")
qs = qs.prefetch_related(
Prefetch(
"detections", queryset=Detection.objects.order_by("-timestamp").select_related("source_image")
)
)

return qs

@@ -650,10 +654,18 @@ def get_queryset(self) -> QuerySet:

# @TODO this should check what the user has access to
project_id = self.request.query_params.get("project")
taxon_occurrences_query = Occurrence.objects.filter(
determination_score__gte=get_active_classification_threshold(self.request),
event__isnull=False,
).distinct()
taxon_occurrences_query = (
Occurrence.objects.filter(
determination_score__gte=get_active_classification_threshold(self.request),
event__isnull=False,
)
.distinct()
.annotate(
first_appearance_timestamp=models.Min("detections__timestamp"),
last_appearance_timestamp=models.Max("detections__timestamp"),
)
.order_by("-first_appearance_timestamp")
)
taxon_occurrences_count_filter = models.Q(
occurrences__determination_score__gte=get_active_classification_threshold(self.request),
occurrences__event__isnull=False,
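These annotations replace the old join against first/last SourceImage objects, so filtering and sorting happen entirely in the database. A standalone sketch of the same pattern, reusing the names from this commit:

from django.db import models

from ami.main.models import Occurrence

occurrences = (
    Occurrence.objects.annotate(
        first_appearance_timestamp=models.Min("detections__timestamp"),
        last_appearance_timestamp=models.Max("detections__timestamp"),
    )
    .exclude(first_appearance_timestamp=None)  # must come after the annotation
    .order_by("-first_appearance_timestamp")
)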
39 changes: 39 additions & 0 deletions ami/main/management/commands/fix_timestamps.py
@@ -0,0 +1,39 @@
import logging

from django.core.management.base import BaseCommand, CommandError  # noqa
from django.db.models import OuterRef, Subquery

from ...models import Detection, SourceImage

logger = logging.getLogger(__name__)


def fix_detection_timestamps(dry_run=True) -> str:
    # Subquery to get the timestamp from the related SourceImage
    source_image_timestamp_subquery = SourceImage.objects.filter(id=OuterRef("source_image_id")).values("timestamp")[
        :1
    ]

    if dry_run:
        # Count all Detection objects where timestamp does not match their SourceImage.timestamp
        count = Detection.objects.exclude(timestamp=Subquery(source_image_timestamp_subquery)).count()
        return f"Would update {count} Detection objects where timestamp does not match their SourceImage.timestamp"

    # Update all Detection objects where timestamp does not match their SourceImage.timestamp
    updated = Detection.objects.exclude(timestamp=Subquery(source_image_timestamp_subquery)).update(
        timestamp=Subquery(source_image_timestamp_subquery)
    )
    return f"Updated {updated} Detection objects where timestamp does not match their SourceImage.timestamp"


class Command(BaseCommand):
    r"""Audit and fix timestamps on Detection objects."""

    help = "Audit and fix timestamps on Detection objects"

    def add_arguments(self, parser):
        parser.add_argument("--dry-run", action="store_true", help="Do not make any changes")

    def handle(self, *args, **options):
        msg = fix_detection_timestamps(dry_run=options["dry_run"])
        self.stdout.write(msg)
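The command is invoked as python manage.py fix_timestamps, with --dry-run to audit first. A sketch of the same calls from code, via Django's call_command:

from django.core.management import call_command

call_command("fix_timestamps", dry_run=True)  # only report mismatched timestamps
call_command("fix_timestamps")  # rewrite Detection.timestamp from SourceImage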
45 changes: 43 additions & 2 deletions ami/main/models.py
@@ -652,6 +652,8 @@ def group_images_into_events(
)

timestamp_groups = ami.utils.dates.group_datetimes_by_gap(image_timestamps, max_time_gap)
# @TODO this event grouping needs testing. Still getting events over 24 hours
# timestamp_groups = ami.utils.dates.group_datetimes_by_shifted_day(image_timestamps)

events = []
for group in timestamp_groups:
@@ -679,11 +681,27 @@ def group_images_into_events(
events.append(event)
SourceImage.objects.filter(deployment=deployment, timestamp__in=group).update(event=event)
event.save() # Update start and end times and other cached fields
logger.info(f"Created/updated event {event} with {len(group)} images for deployment {deployment}.")
logger.info(
f"Created/updated event {event} with {len(group)} images for deployment {deployment}. "
f"Duration: {event.duration_label()}"
)

if delete_empty:
delete_empty_events()

events_over_24_hours = Event.objects.filter(
deployment=deployment, start__lt=models.F("end") - datetime.timedelta(days=1)
)
if events_over_24_hours.count():
logger.warning(f"Found {events_over_24_hours.count()} events over 24 hours in deployment {deployment}. ")
events_starting_before_noon = Event.objects.filter(
deployment=deployment, start__time__lt=datetime.time(12)
)
if events_starting_before_noon.count():
logger.warning(
f"Found {events_starting_before_noon.count()} events starting before noon in deployment {deployment}. "
)

return events


@@ -1372,6 +1390,7 @@ class Detection(BaseModel):
null=True,
blank=True,
)
# Time that the detection was created by the algorithm in the ML backend
detection_time = models.DateTimeField(null=True, blank=True)
# @TODO not sure if this detection score is ever used
# I think it was intended to be the score of the detection algorithm (bbox score)
@@ -1460,6 +1479,21 @@ def associate_new_occurrence(self) -> "Occurrence":
self.source_image.save()
return occurrence

def update_calculated_fields(self, save=True):
needs_update = False
if not self.timestamp:
self.timestamp = self.source_image.timestamp
needs_update = True
if save and needs_update:
self.save(update_calculated_fields=False)

def save(self, update_calculated_fields=True, *args, **kwargs):
super().save(*args, **kwargs)
if self.pk and update_calculated_fields:
self.update_calculated_fields(save=True)
# if not self.occurrence:
# self.associate_new_occurrence()


@final
class OccurrenceManager(models.Manager):
@@ -1495,7 +1529,7 @@ def __str__(self) -> str:
return name

def detections_count(self) -> int | None:
# Annotaions don't seem to work with nested serializers
# Annotations don't seem to work with nested serializers
return self.detections.count()

@functools.cached_property
@@ -1526,6 +1560,13 @@ def first_appearance_time(self) -> datetime.time | None:
"""
return None

def last_appearance_timestamp(self) -> datetime.datetime | None:
"""
Return the timestamp of the last appearance.
ONLY if it has been added with a query annotation.
"""
return None

def duration(self) -> datetime.timedelta | None:
first = self.first_appearance
last = self.last_appearance
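With this override, a detection created without a timestamp is backfilled from its capture on a second save pass. A minimal sketch, assuming capture is an existing SourceImage and the remaining Detection fields are optional:

from ami.main.models import Detection

detection = Detection.objects.create(source_image=capture)  # timestamp omitted
detection.refresh_from_db()
assert detection.timestamp == capture.timestamp  # backfilled by save()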
7 changes: 4 additions & 3 deletions ami/ml/models/pipeline.py
@@ -132,6 +132,7 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
source_images = set()

for detection in results.detections:
# @TODO use bulk create, or optimize this in some way
print(detection)
assert detection.algorithm
algo, _created = Algorithm.objects.get_or_create(
Expand All @@ -153,12 +154,12 @@ def save_results(results: PipelineResponse, job_id: int | None = None) -> list[m
new_detection = Detection.objects.create(
source_image=source_image,
bbox=list(detection.bbox.dict().values()),
timestamp=source_image.timestamp,
path=detection.crop_image_url or "",
detection_time=detection.timestamp,
)
new_detection.detection_algorithm = algo
# new_detection.detection_time = detection.inference_time
new_detection.timestamp = now() # @TODO what is this field for
# @TODO lookup and assign related algorithm object
# new_detection.detection_algorithm = detection.algorithm
new_detection.save()
print("Created new detection", new_detection)
created_objects.append(new_detection)
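One possible shape for the bulk-create @TODO above, as a hypothetical sketch only. Note that bulk_create bypasses the custom Detection.save(), so the timestamp must be set explicitly, as this commit now does at creation time:

pending = []
for detection in results.detections:
    source_image = ...  # same per-detection source image lookup as above
    pending.append(
        Detection(
            source_image=source_image,
            bbox=list(detection.bbox.dict().values()),
            timestamp=source_image.timestamp,  # bulk_create skips save()
            path=detection.crop_image_url or "",
            detection_time=detection.timestamp,
            # detection_algorithm assignment omitted for brevity
        )
    )
Detection.objects.bulk_create(pending)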
44 changes: 43 additions & 1 deletion ami/utils/dates.py
@@ -120,8 +120,50 @@ def group_datetimes_by_gap(
    return groups


def group_datetimes_by_shifted_day(timestamps: list[datetime.datetime]) -> list[list[datetime.datetime]]:
    """
    @TODO: Needs testing

    Images are captured from evening until morning of the next day.
    Assume that the first image is taken after noon and the last image is taken before noon.
    In that case, we can shift the timestamps so that the x-axis is centered around 12PM,
    then group the images by day.

    One way to do this directly in postgres is to use the following query:

        SELECT date_trunc('day', timestamp + interval '12 hours') as day, count(*)
        FROM images
        GROUP BY day

    >>> timestamps = [
    ...     datetime.datetime(2021, 1, 1, 0, 0, 0),
    ...     datetime.datetime(2021, 1, 1, 0, 1, 0),
    ...     datetime.datetime(2021, 1, 1, 0, 2, 0),
    ...     datetime.datetime(2021, 1, 2, 0, 0, 0),
    ...     datetime.datetime(2021, 1, 2, 0, 1, 0),
    ...     datetime.datetime(2021, 1, 2, 0, 2, 0),
    ... ]
    >>> result = group_datetimes_by_shifted_day(timestamps)
    >>> len(result)
    2
    """
    # Shift timestamps back 12 hours so that a "day" runs from noon to noon,
    # then group by the shifted date. Keep the original (unshifted) timestamps
    # in the returned groups.
    time_delta = datetime.timedelta(hours=12)

    groups: dict[datetime.date, list[datetime.datetime]] = {}
    for timestamp in sorted(timestamps):
        day = (timestamp - time_delta).date()
        groups.setdefault(day, []).append(timestamp)

    # Convert the dictionary to a list of lists
    return list(groups.values())


def shift_to_nighttime(hours: list[int], values: list) -> tuple[list[int], list]:
    """Shift hours so that the x-axis is centered around 12PM."""
    """Another strategy to shift hours so that the x-axis is centered around 12PM."""

    split_index = 0
    for i, hour in enumerate(hours):
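A quick sanity check of the noon-to-noon grouping in group_datetimes_by_shifted_day, with a realistic capture session: an evening image and the following morning's image should fall into a single group.

import datetime

from ami.utils.dates import group_datetimes_by_shifted_day

stamps = [
    datetime.datetime(2021, 1, 1, 22, 0),  # evening of Jan 1
    datetime.datetime(2021, 1, 2, 2, 0),  # early morning of Jan 2
]
assert len(group_datetimes_by_shifted_day(stamps)) == 1  # one night, one group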
