Skip to content

Commit

Permalink
Renamed project filter query parameter to project_id for Project …
Browse files Browse the repository at this point in the history
…Summary entities (#668)

* Renamed "project" filter query parameter to "project_id" for Project Summary entities

* Fixed project association in create_ml_pipeline function

* Added tests for filtering by "project_id" on Project Summary entities

* Refactored open api project_id docs params, get_project logic and moved it to requests.py

* Applied changes to the frontend

* Updated all entities in the project page to filter by project_id

* Updated tests to use project_id parameter instead of project

* Removed occurrences__project query param
  • Loading branch information
mohamedelabbas1996 authored Jan 16, 2025
1 parent 284cb14 commit 9181362
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 32 deletions.
11 changes: 9 additions & 2 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from django.db.models.query import QuerySet
from django.forms import IntegerField
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.response import Response

from ami.main.api.views import DefaultViewSet
from ami.utils.fields import url_boolean_param
from ami.utils.requests import get_active_project, project_id_doc_param

from .models import Job, JobState, MLJob
from .serializers import JobListSerializer, JobSerializer
Expand Down Expand Up @@ -35,7 +37,6 @@ class JobViewSet(DefaultViewSet):
"""

queryset = Job.objects.select_related(
"project",
"deployment",
"pipeline",
"source_image_collection",
Expand Down Expand Up @@ -128,7 +129,9 @@ def perform_create(self, serializer):

def get_queryset(self) -> QuerySet:
jobs = super().get_queryset()

project = get_active_project(self.request)
if project:
jobs = jobs.filter(project=project)
cutoff_hours = IntegerField(required=False, min_value=0).clean(
self.request.query_params.get("cutoff_hours", Job.FAILED_CUTOFF_HOURS)
)
Expand All @@ -138,3 +141,7 @@ def get_queryset(self) -> QuerySet:
status=JobState.failed_states(),
updated_at__lt=cutoff_datetime,
)

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)
2 changes: 1 addition & 1 deletion ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def get_occurrence_images(self, obj):

# request = self.context.get("request")
# project_id = request.query_params.get("project") if request else None
project_id = self.context["request"].query_params["project"]
project_id = self.context["request"].query_params["project_id"]
classification_threshold = get_active_classification_threshold(self.context["request"])

return obj.occurrence_images(
Expand Down
106 changes: 83 additions & 23 deletions ami/main/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from django.forms import BooleanField, CharField, IntegerField
from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import extend_schema
from rest_framework import exceptions as api_exceptions
from rest_framework import filters, serializers, status, viewsets
from rest_framework.decorators import action
Expand All @@ -24,7 +25,7 @@
from ami.base.pagination import LimitOffsetPaginationWithPermissions
from ami.base.permissions import IsActiveStaffOrReadOnly
from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer
from ami.utils.requests import get_active_classification_threshold
from ami.utils.requests import get_active_classification_threshold, get_active_project, project_id_doc_param
from ami.utils.storages import ConnectionTestResult

from ..models import (
Expand Down Expand Up @@ -137,7 +138,6 @@ class DeploymentViewSet(DefaultViewSet):
"""

queryset = Deployment.objects.select_related("project", "device", "research_site")
filterset_fields = ["project"]
ordering_fields = [
"created_at",
"updated_at",
Expand All @@ -160,7 +160,9 @@ def get_serializer_class(self):

def get_queryset(self) -> QuerySet:
qs = super().get_queryset()

project = get_active_project(self.request)
if project:
qs = qs.filter(project=project)
num_example_captures = 10
if self.action == "retrieve":
qs = qs.prefetch_related(
Expand Down Expand Up @@ -204,6 +206,10 @@ def sync(self, _request, pk=None) -> Response:
else:
raise api_exceptions.ValidationError(detail="Deployment must have a data source to sync captures from")

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class EventViewSet(DefaultViewSet):
"""
Expand All @@ -212,7 +218,7 @@ class EventViewSet(DefaultViewSet):

queryset = Event.objects.all()
serializer_class = EventSerializer
filterset_fields = ["deployment", "project"]
filterset_fields = ["deployment"]
ordering_fields = [
"created_at",
"updated_at",
Expand All @@ -237,6 +243,9 @@ def get_serializer_class(self):

def get_queryset(self) -> QuerySet:
qs: QuerySet = super().get_queryset()
project = get_active_project(self.request)
if project:
qs = qs.filter(project=project)
qs = qs.filter(deployment__isnull=False)
qs = qs.annotate(
duration=models.F("end") - models.F("start"),
Expand Down Expand Up @@ -363,6 +372,10 @@ def timeline(self, request, pk=None):
)
return Response(serializer.data)

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class SourceImageViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -549,7 +562,7 @@ class SourceImageCollectionViewSet(DefaultViewSet):
)
serializer_class = SourceImageCollectionSerializer

filterset_fields = ["project", "method"]
filterset_fields = ["method"]
ordering_fields = [
"created_at",
"updated_at",
Expand All @@ -562,11 +575,14 @@ class SourceImageCollectionViewSet(DefaultViewSet):

def get_queryset(self) -> QuerySet:
classification_threshold = get_active_classification_threshold(self.request)
queryset = (
super()
.get_queryset()
.with_occurrences_count(classification_threshold=classification_threshold) # type: ignore
.with_taxa_count(classification_threshold=classification_threshold)
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
if project:
query_set = query_set.filter(project=project)
queryset = query_set.with_occurrences_count(
classification_threshold=classification_threshold
).with_taxa_count( # type: ignore
classification_threshold=classification_threshold
)
return queryset

Expand Down Expand Up @@ -647,6 +663,10 @@ def remove(self, request, pk=None):
}
)

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class SourceImageUploadViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -903,7 +923,6 @@ class OccurrenceViewSet(DefaultViewSet):
filterset_fields = [
"event",
"deployment",
"project",
"determination__rank",
"detections__source_image",
]
Expand Down Expand Up @@ -933,7 +952,10 @@ def get_serializer_class(self):
return OccurrenceSerializer

def get_queryset(self) -> QuerySet:
project = get_active_project(self.request)
qs = super().get_queryset()
if project:
qs = qs.filter(project=project)
qs = qs.select_related(
"determination",
"deployment",
Expand Down Expand Up @@ -961,6 +983,10 @@ def get_queryset(self) -> QuerySet:

return qs

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class TaxonViewSet(DefaultViewSet):
"""
Expand Down Expand Up @@ -1042,23 +1068,22 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]:
"""

occurrence_id = self.request.query_params.get("occurrence")
project_id = self.request.query_params.get("project") or self.request.query_params.get("occurrences__project")
project = get_active_project(self.request)
deployment_id = self.request.query_params.get("deployment") or self.request.query_params.get(
"occurrences__deployment"
)
event_id = self.request.query_params.get("event") or self.request.query_params.get("occurrences__event")
collection_id = self.request.query_params.get("collection")

filter_active = any([occurrence_id, project_id, deployment_id, event_id, collection_id])
filter_active = any([occurrence_id, project, deployment_id, event_id, collection_id])

if not project_id:
if not project:
# Raise a 400 if no project is specified
raise api_exceptions.ValidationError(detail="A project must be specified")

queryset = super().get_queryset()
try:
if project_id:
project = Project.objects.get(id=project_id)
if project:
queryset = queryset.filter(occurrences__project=project)
if occurrence_id:
occurrence = Occurrence.objects.get(id=occurrence_id)
Expand Down Expand Up @@ -1180,6 +1205,10 @@ def get_queryset(self) -> QuerySet:

return qs

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

# def retrieve(self, request: Request, *args, **kwargs) -> Response:
# """
# Override the serializer to include the recursive occurrences count
Expand Down Expand Up @@ -1207,16 +1236,15 @@ class ClassificationViewSet(DefaultViewSet):

class SummaryView(GenericAPIView):
permission_classes = [IsActiveStaffOrReadOnly]
filterset_fields = ["project"]

@extend_schema(parameters=[project_id_doc_param])
def get(self, request):
"""
Return counts of all models.
"""
project_id = request.query_params.get("project")
project = get_active_project(request)
confidence_threshold = get_active_classification_threshold(request)
if project_id:
project = Project.objects.get(id=project_id)
if project:
data = {
"projects_count": Project.objects.count(), # @TODO filter by current user, here and everywhere!
"deployments_count": Deployment.objects.filter(project=project).count(),
Expand Down Expand Up @@ -1358,13 +1386,24 @@ class SiteViewSet(DefaultViewSet):

queryset = Site.objects.all()
serializer_class = SiteSerializer
filterset_fields = ["project", "deployments"]
filterset_fields = ["deployments"]
ordering_fields = [
"created_at",
"updated_at",
"name",
]

def get_queryset(self) -> QuerySet:
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
if project:
query_set = query_set.filter(project=project)
return query_set

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class DeviceViewSet(DefaultViewSet):
"""
Expand All @@ -1373,13 +1412,24 @@ class DeviceViewSet(DefaultViewSet):

queryset = Device.objects.all()
serializer_class = DeviceSerializer
filterset_fields = ["project", "deployments"]
filterset_fields = ["deployments"]
ordering_fields = [
"created_at",
"updated_at",
"name",
]

def get_queryset(self) -> QuerySet:
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
if project:
query_set = query_set.filter(project=project)
return query_set

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)


class StorageSourceConnectionTestSerializer(serializers.Serializer):
subdir = serializers.CharField(required=False, allow_null=True)
Expand All @@ -1393,13 +1443,23 @@ class StorageSourceViewSet(DefaultViewSet):

queryset = S3StorageSource.objects.all()
serializer_class = StorageSourceSerializer
filterset_fields = ["project", "deployments"]
filterset_fields = ["deployments"]
ordering_fields = [
"created_at",
"updated_at",
"name",
]

def get_queryset(self) -> QuerySet:
query_set: QuerySet = super().get_queryset()
project = get_active_project(self.request)
query_set = query_set.filter(project=project)
return query_set

@extend_schema(parameters=[project_id_doc_param])
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

@action(detail=True, methods=["post"], name="test", serializer_class=StorageSourceConnectionTestSerializer)
def test(self, request: Request, pk=None) -> Response:
"""
Expand Down
Loading

0 comments on commit 9181362

Please sign in to comment.