Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to collection sampling methods #717

Merged
merged 6 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions ami/base/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import datetime

from rest_framework import serializers


class DateStringField(serializers.CharField):
"""
Field that validates and stores dates as YYYY-MM-DD strings.
Needed for storing dates as strings in JSON fields but keep validation.
"""

def to_internal_value(self, value: str | None) -> str | None:
if value is None:
return None

try:
# Validate the date format by parsing it
datetime.datetime.strptime(value, "%Y-%m-%d")
return value
except ValueError as e:
raise serializers.ValidationError("Invalid date format. Use YYYY-MM-DD format.") from e

@classmethod
def to_date(cls, value: str | None) -> datetime.date | None:
"""Convert a YYYY-MM-DD string to a Python date object for ORM queries."""
if value is None:
return None
return datetime.datetime.strptime(value, "%Y-%m-%d").date()
8 changes: 7 additions & 1 deletion ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db.models import QuerySet
from rest_framework import serializers

from ami.base.fields import DateStringField
from ami.base.serializers import DefaultSerializer, MinimalNestedModelSerializer, get_current_user, reverse_with_params
from ami.jobs.models import Job
from ami.main.models import create_source_image_from_upload
Expand Down Expand Up @@ -1025,9 +1026,14 @@ class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer):
# use for the "common_combined" method
minute_interval = serializers.IntegerField(required=False, allow_null=True)
max_num = serializers.IntegerField(required=False, allow_null=True)
shuffle = serializers.BooleanField(required=False, allow_null=True)

month_start = serializers.IntegerField(required=False, allow_null=True)
month_end = serializers.IntegerField(required=False, allow_null=True)

date_start = DateStringField(required=False, allow_null=True)
date_end = DateStringField(required=False, allow_null=True)

hour_start = serializers.IntegerField(required=False, allow_null=True)
hour_end = serializers.IntegerField(required=False, allow_null=True)

Expand All @@ -1039,9 +1045,9 @@ class SourceImageCollectionCommonKwargsSerializer(serializers.Serializer):
deployment_id = serializers.IntegerField(required=False, allow_null=True)
position = serializers.IntegerField(required=False, allow_null=True)

# Don't return the kwargs if they are empty
def to_representation(self, instance):
data = super().to_representation(instance)
# Don't return the kwargs if they are empty
return {key: value for key, value in data.items() if value is not None}


Expand Down
74 changes: 41 additions & 33 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import ami.tasks
import ami.utils
from ami.base.fields import DateStringField
from ami.base.models import BaseModel
from ami.main import charts
from ami.users.models import User
Expand Down Expand Up @@ -1521,8 +1522,8 @@ def set_dimensions_for_collection(


def sample_captures_by_interval(
minute_interval: int = 10,
qs: models.QuerySet[SourceImage] | None = None,
minute_interval: int,
qs: models.QuerySet[SourceImage],
max_num: int | None = None,
) -> typing.Generator[SourceImage, None, None]:
"""
Expand All @@ -1532,9 +1533,6 @@ def sample_captures_by_interval(
last_capture = None
total = 0

if not qs:
raise ValueError("Queryset must be provided, and it should be limited to a Project.")

qs = qs.exclude(timestamp=None).order_by("timestamp")

for capture in qs.all():
Expand All @@ -1555,7 +1553,7 @@ def sample_captures_by_interval(

def sample_captures_by_position(
position: int,
qs: models.QuerySet[SourceImage] | None = None,
qs: models.QuerySet[SourceImage],
) -> typing.Generator[SourceImage | None, None, None]:
"""
Return the n-th position capture from each event.
Expand All @@ -1564,9 +1562,6 @@ def sample_captures_by_position(
If position = -1, the last capture from each event will be returned.
"""

if not qs:
raise ValueError("Queryset must be provided, and it should be limited to a Project.")

qs = qs.exclude(timestamp=None).order_by("timestamp")

events = Event.objects.filter(captures__in=qs).distinct()
Expand All @@ -1593,7 +1588,7 @@ def sample_captures_by_position(

def sample_captures_by_nth(
nth: int,
qs: models.QuerySet[SourceImage] | None = None,
qs: models.QuerySet[SourceImage],
) -> typing.Generator[SourceImage, None, None]:
"""
Return every nth capture from each event.
Expand All @@ -1602,9 +1597,6 @@ def sample_captures_by_nth(
If nth = 5, every 5th capture from each event will be returned.
"""

if not qs:
raise ValueError("Queryset must be provided, and it should be limited to a Project.")

qs = qs.exclude(timestamp=None).order_by("timestamp")

events = Event.objects.filter(captures__in=qs).distinct()
Expand Down Expand Up @@ -2973,35 +2965,51 @@ def sample_manual(self, image_ids: list[int]):
def sample_common_combined(
self,
minute_interval: int | None = None,
max_num: int | None = 100,
max_num: int | None = None,
shuffle: bool = True, # This is applicable if max_num is set and minute_interval is not set
hour_start: int | None = None,
hour_end: int | None = None,
month_start: datetime.date | None = None,
month_end: datetime.date | None = None,
day_start: datetime.date | None = None,
day_end: datetime.date | None = None,
month_start: int | None = None,
month_end: int | None = None,
date_start: str | None = None,
date_end: str | None = None,
) -> models.QuerySet | typing.Generator[SourceImage, None, None]:
qs = self.get_queryset()
if month_start:

if date_start is not None:
qs = qs.filter(timestamp__date__gte=DateStringField.to_date(date_start))
if date_end is not None:
qs = qs.filter(timestamp__date__lte=DateStringField.to_date(date_end))

if month_start is not None:
qs = qs.filter(timestamp__month__gte=month_start)
if month_end:
if month_end is not None:
qs = qs.filter(timestamp__month__lte=month_end)
if day_start:
qs = qs.filter(timestamp__day__gte=day_start)
if day_end:
qs = qs.filter(timestamp__day__lte=day_end)
if hour_start:

if hour_start is not None and hour_end is not None:
if hour_start < hour_end:
# Hour range within the same day (e.g., 08:00 to 15:00)
qs = qs.filter(timestamp__hour__gte=hour_start, timestamp__hour__lte=hour_end)
else:
# Hour range has Midnight crossover: (e.g., 17:00 to 06:00)
qs = qs.filter(models.Q(timestamp__hour__gte=hour_start) | models.Q(timestamp__hour__lte=hour_end))
elif hour_start is not None:
qs = qs.filter(timestamp__hour__gte=hour_start)
if hour_end:
elif hour_end is not None:
qs = qs.filter(timestamp__hour__lte=hour_end)
if not minute_interval and max_num:
qs = qs[:max_num]
if minute_interval:

if minute_interval is not None:
# @TODO can this be done in the database and return a queryset?
# this currently returns a list of source images
# Ensure the queryset is limited to the project
qs = qs.filter(project=self.project)
qs = sample_captures_by_interval(minute_interval, qs=qs, max_num=max_num)
qs = sample_captures_by_interval(minute_interval=minute_interval, qs=qs, max_num=max_num)
else:
if max_num is not None:
if shuffle:
qs = qs.order_by("?")
qs = qs[:max_num]

return qs

def sample_interval(
Expand All @@ -3016,19 +3024,19 @@ def sample_interval(
qs = qs.exclude(event__in=exclude_events)
qs.exclude(event__in=exclude_events)
qs = qs.filter(project=self.project)
return sample_captures_by_interval(minute_interval, qs=qs)
return sample_captures_by_interval(minute_interval=minute_interval, qs=qs)

def sample_positional(self, position: int = -1):
"""Sample the single nth source image from all events in the project"""

qs = self.get_queryset()
return sample_captures_by_position(position, qs=qs)
return sample_captures_by_position(position=position, qs=qs)

def sample_nth(self, nth: int):
"""Sample every nth source image from all events in the project"""

qs = self.get_queryset()
return sample_captures_by_nth(nth, qs=qs)
return sample_captures_by_nth(nth=nth, qs=qs)

def sample_random_from_each_event(self, num_each: int = 10):
"""Sample n random source images from each event in the project."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { FormField } from 'components/form/form-field'
import { isValid } from 'date-fns'

import {
FormActions,
FormError,
Expand All @@ -24,9 +26,29 @@ type CollectionFormValues = FormValues & {
month_end: string | undefined
hour_start: number | undefined
hour_end: number | undefined
date_start: string | undefined
date_end: string | undefined
}
}

// simple date string config

const kwargs_date_config = {
label: 'Date',
description: 'Format: YYYY-MM-DD',
rules: {
validate: (value: any): string | undefined => {
if (!value) return undefined

if (!isValid(new Date(value))) {
return 'Date must be in YYYY-MM-DD format'
}

return undefined
},
},
}

const config: FormConfig = {
name: {
label: translate(STRING.FIELD_LABEL_NAME),
Expand All @@ -51,6 +73,7 @@ const config: FormConfig = {
},
'kwargs.max_num': {
label: 'Max number of images',
description: 'When set, the collection will be a random sample',
},
'kwargs.minute_interval': {
label: 'Minutes between captures',
Expand All @@ -67,6 +90,14 @@ const config: FormConfig = {
'kwargs.hour_end': {
label: 'Latest hour',
},
'kwargs.date_start': {
...kwargs_date_config,
label: 'Earliest date',
},
'kwargs.date_end': {
...kwargs_date_config,
label: 'Latest date',
},
}

export const CollectionDetailsForm = ({
Expand Down Expand Up @@ -179,6 +210,20 @@ export const CollectionDetailsForm = ({
control={control}
/>
</FormRow>
<FormRow>
<FormField
name="kwargs.date_start"
type="text"
config={config}
control={control}
/>
<FormField
name="kwargs.date_end"
type="text"
config={config}
control={control}
/>
</FormRow>
<FormRow>
<FormField
name="method"
Expand Down