From 9181362f479b19397716b825f81ad9d40874e45e Mon Sep 17 00:00:00 2001 From: Mohamed Elabbas Date: Wed, 15 Jan 2025 19:52:40 -0500 Subject: [PATCH 1/5] Renamed `project` filter query parameter to `project_id` for Project Summary entities (#668) * Renamed "project" filter query parameter to "project_id" for Project Summary entities * Fixed project association in create_ml_pipeline function * Added tests for filtering by "project_id" on Project Summary entities * Refactored open api project_id docs params, get_project logic and moved it to requests.py * Applied changes to the frontend * Updated all entities in the project page to filter by project_id * Updated tests to use project_id parameter instead of project * Removed occurrences__project query param --- ami/jobs/views.py | 11 ++- ami/main/api/serializers.py | 2 +- ami/main/api/views.py | 106 ++++++++++++++++++++------ ami/main/tests.py | 136 +++++++++++++++++++++++++++++++++- ami/ml/views.py | 16 ++++ ami/tests/fixtures/main.py | 2 +- ami/utils/requests.py | 18 +++++ ui/src/data-services/utils.ts | 2 +- 8 files changed, 261 insertions(+), 32 deletions(-) diff --git a/ami/jobs/views.py b/ami/jobs/views.py index f8f2e34a2..e783ff9a5 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -3,11 +3,13 @@ from django.db.models.query import QuerySet from django.forms import IntegerField from django.utils import timezone +from drf_spectacular.utils import extend_schema from rest_framework.decorators import action from rest_framework.response import Response from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param +from ami.utils.requests import get_active_project, project_id_doc_param from .models import Job, JobState, MLJob from .serializers import JobListSerializer, JobSerializer @@ -35,7 +37,6 @@ class JobViewSet(DefaultViewSet): """ queryset = Job.objects.select_related( - "project", "deployment", "pipeline", "source_image_collection", @@ -128,7 +129,9 @@ def perform_create(self, serializer): def get_queryset(self) -> QuerySet: jobs = super().get_queryset() - + project = get_active_project(self.request) + if project: + jobs = jobs.filter(project=project) cutoff_hours = IntegerField(required=False, min_value=0).clean( self.request.query_params.get("cutoff_hours", Job.FAILED_CUTOFF_HOURS) ) @@ -138,3 +141,7 @@ def get_queryset(self) -> QuerySet: status=JobState.failed_states(), updated_at__lt=cutoff_datetime, ) + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index eb7164489..7c96a052b 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -503,7 +503,7 @@ def get_occurrence_images(self, obj): # request = self.context.get("request") # project_id = request.query_params.get("project") if request else None - project_id = self.context["request"].query_params["project"] + project_id = self.context["request"].query_params["project_id"] classification_threshold = get_active_classification_threshold(self.context["request"]) return obj.occurrence_images( diff --git a/ami/main/api/views.py b/ami/main/api/views.py index e8e76e537..7ca73bb30 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -10,6 +10,7 @@ from django.forms import BooleanField, CharField, IntegerField from django.utils import timezone from django_filters.rest_framework import DjangoFilterBackend +from drf_spectacular.utils import extend_schema from rest_framework import exceptions as api_exceptions from rest_framework import filters, serializers, status, viewsets from rest_framework.decorators import action @@ -24,7 +25,7 @@ from ami.base.pagination import LimitOffsetPaginationWithPermissions from ami.base.permissions import IsActiveStaffOrReadOnly from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer -from ami.utils.requests import get_active_classification_threshold +from ami.utils.requests import get_active_classification_threshold, get_active_project, project_id_doc_param from ami.utils.storages import ConnectionTestResult from ..models import ( @@ -137,7 +138,6 @@ class DeploymentViewSet(DefaultViewSet): """ queryset = Deployment.objects.select_related("project", "device", "research_site") - filterset_fields = ["project"] ordering_fields = [ "created_at", "updated_at", @@ -160,7 +160,9 @@ def get_serializer_class(self): def get_queryset(self) -> QuerySet: qs = super().get_queryset() - + project = get_active_project(self.request) + if project: + qs = qs.filter(project=project) num_example_captures = 10 if self.action == "retrieve": qs = qs.prefetch_related( @@ -204,6 +206,10 @@ def sync(self, _request, pk=None) -> Response: else: raise api_exceptions.ValidationError(detail="Deployment must have a data source to sync captures from") + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class EventViewSet(DefaultViewSet): """ @@ -212,7 +218,7 @@ class EventViewSet(DefaultViewSet): queryset = Event.objects.all() serializer_class = EventSerializer - filterset_fields = ["deployment", "project"] + filterset_fields = ["deployment"] ordering_fields = [ "created_at", "updated_at", @@ -237,6 +243,9 @@ def get_serializer_class(self): def get_queryset(self) -> QuerySet: qs: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + qs = qs.filter(project=project) qs = qs.filter(deployment__isnull=False) qs = qs.annotate( duration=models.F("end") - models.F("start"), @@ -363,6 +372,10 @@ def timeline(self, request, pk=None): ) return Response(serializer.data) + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class SourceImageViewSet(DefaultViewSet): """ @@ -549,7 +562,7 @@ class SourceImageCollectionViewSet(DefaultViewSet): ) serializer_class = SourceImageCollectionSerializer - filterset_fields = ["project", "method"] + filterset_fields = ["method"] ordering_fields = [ "created_at", "updated_at", @@ -562,11 +575,14 @@ class SourceImageCollectionViewSet(DefaultViewSet): def get_queryset(self) -> QuerySet: classification_threshold = get_active_classification_threshold(self.request) - queryset = ( - super() - .get_queryset() - .with_occurrences_count(classification_threshold=classification_threshold) # type: ignore - .with_taxa_count(classification_threshold=classification_threshold) + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(project=project) + queryset = query_set.with_occurrences_count( + classification_threshold=classification_threshold + ).with_taxa_count( # type: ignore + classification_threshold=classification_threshold ) return queryset @@ -647,6 +663,10 @@ def remove(self, request, pk=None): } ) + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class SourceImageUploadViewSet(DefaultViewSet): """ @@ -903,7 +923,6 @@ class OccurrenceViewSet(DefaultViewSet): filterset_fields = [ "event", "deployment", - "project", "determination__rank", "detections__source_image", ] @@ -933,7 +952,10 @@ def get_serializer_class(self): return OccurrenceSerializer def get_queryset(self) -> QuerySet: + project = get_active_project(self.request) qs = super().get_queryset() + if project: + qs = qs.filter(project=project) qs = qs.select_related( "determination", "deployment", @@ -961,6 +983,10 @@ def get_queryset(self) -> QuerySet: return qs + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class TaxonViewSet(DefaultViewSet): """ @@ -1042,23 +1068,22 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]: """ occurrence_id = self.request.query_params.get("occurrence") - project_id = self.request.query_params.get("project") or self.request.query_params.get("occurrences__project") + project = get_active_project(self.request) deployment_id = self.request.query_params.get("deployment") or self.request.query_params.get( "occurrences__deployment" ) event_id = self.request.query_params.get("event") or self.request.query_params.get("occurrences__event") collection_id = self.request.query_params.get("collection") - filter_active = any([occurrence_id, project_id, deployment_id, event_id, collection_id]) + filter_active = any([occurrence_id, project, deployment_id, event_id, collection_id]) - if not project_id: + if not project: # Raise a 400 if no project is specified raise api_exceptions.ValidationError(detail="A project must be specified") queryset = super().get_queryset() try: - if project_id: - project = Project.objects.get(id=project_id) + if project: queryset = queryset.filter(occurrences__project=project) if occurrence_id: occurrence = Occurrence.objects.get(id=occurrence_id) @@ -1180,6 +1205,10 @@ def get_queryset(self) -> QuerySet: return qs + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + # def retrieve(self, request: Request, *args, **kwargs) -> Response: # """ # Override the serializer to include the recursive occurrences count @@ -1207,16 +1236,15 @@ class ClassificationViewSet(DefaultViewSet): class SummaryView(GenericAPIView): permission_classes = [IsActiveStaffOrReadOnly] - filterset_fields = ["project"] + @extend_schema(parameters=[project_id_doc_param]) def get(self, request): """ Return counts of all models. """ - project_id = request.query_params.get("project") + project = get_active_project(request) confidence_threshold = get_active_classification_threshold(request) - if project_id: - project = Project.objects.get(id=project_id) + if project: data = { "projects_count": Project.objects.count(), # @TODO filter by current user, here and everywhere! "deployments_count": Deployment.objects.filter(project=project).count(), @@ -1358,13 +1386,24 @@ class SiteViewSet(DefaultViewSet): queryset = Site.objects.all() serializer_class = SiteSerializer - filterset_fields = ["project", "deployments"] + filterset_fields = ["deployments"] ordering_fields = [ "created_at", "updated_at", "name", ] + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(project=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class DeviceViewSet(DefaultViewSet): """ @@ -1373,13 +1412,24 @@ class DeviceViewSet(DefaultViewSet): queryset = Device.objects.all() serializer_class = DeviceSerializer - filterset_fields = ["project", "deployments"] + filterset_fields = ["deployments"] ordering_fields = [ "created_at", "updated_at", "name", ] + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(project=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class StorageSourceConnectionTestSerializer(serializers.Serializer): subdir = serializers.CharField(required=False, allow_null=True) @@ -1393,13 +1443,23 @@ class StorageSourceViewSet(DefaultViewSet): queryset = S3StorageSource.objects.all() serializer_class = StorageSourceSerializer - filterset_fields = ["project", "deployments"] + filterset_fields = ["deployments"] ordering_fields = [ "created_at", "updated_at", "name", ] + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + query_set = query_set.filter(project=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + @action(detail=True, methods=["post"], name="test", serializer_class=StorageSourceConnectionTestSerializer) def test(self, request: Request, pk=None) -> Response: """ diff --git a/ami/main/tests.py b/ami/main/tests.py index d422db060..93f5a87da 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -1,12 +1,26 @@ import datetime import logging -from django.db import connection +from django.db import connection, models from django.test import TestCase +from rest_framework import status from rest_framework.test import APIRequestFactory, APITestCase from rich import print -from ami.main.models import Event, Occurrence, Project, Taxon, TaxonRank, group_images_into_events +from ami.main.models import ( + Device, + Event, + Occurrence, + Project, + S3StorageSource, + Site, + SourceImage, + SourceImageCollection, + Taxon, + TaxonRank, + group_images_into_events, +) +from ami.ml.models.pipeline import Pipeline from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project from ami.users.models import User @@ -535,7 +549,7 @@ def setUp(self) -> None: def test_occurrences_for_project(self): # Test that occurrences are specific to each project for project in [self.project_one, self.project_two]: - response = self.client.get(f"/api/v2/occurrences/?project={project.pk}") + response = self.client.get(f"/api/v2/occurrences/?project_id={project.pk}") self.assertEqual(response.status_code, 200) self.assertEqual(response.json()["count"], Occurrence.objects.filter(project=project).count()) @@ -578,7 +592,7 @@ def _test_taxa_for_project(self, project: Project): """ from ami.main.models import Taxon - response = self.client.get(f"/api/v2/taxa/?project={project.pk}") + response = self.client.get(f"/api/v2/taxa/?project_id={project.pk}") self.assertEqual(response.status_code, 200) project_occurred_taxa = Taxon.objects.filter(occurrences__project=project).distinct() # project_any_taxa = Taxon.objects.filter(projects=project) @@ -754,3 +768,117 @@ def test_update_subdir(self): self.other_subdir: self.images_per_dir, } self.assertDictEqual(dict(counts), expected_counts) + + +class TestProjectSettingsFiltering(APITestCase): + """Test Project Settings filter by project_id""" + + def setUp(self) -> None: + for _ in range(3): + project, deployment = setup_test_project(reuse=False) + create_taxa(project=project) + create_captures(deployment=deployment) + group_images_into_events(deployment=deployment) + create_occurrences(deployment=deployment, num=5) + self.project_ids = [project.id for project in Project.objects.all()] + + self.user = User.objects.create_user( # type: ignore + email="testuser@insectai.org", + is_staff=True, + ) + self.factory = APIRequestFactory() + self.client.force_authenticate(user=self.user) + return super().setUp() + + def test_project_summary(self): + project_id = self.project_ids[1] + endpoint_url = f"/api/v2/status/summary/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + self.assertEqual(response.status_code, status.HTTP_200_OK) + project = Project.objects.get(pk=project_id) + + self.assertEqual(response_data["deployments_count"], project.deployments_count()) + self.assertEqual( + response_data["taxa_count"], + Taxon.objects.annotate(occurrences_count=models.Count("occurrences")) + .filter( + occurrences_count__gt=0, + occurrences__determination_score__gte=0, + occurrences__project=project, + ) + .distinct() + .count(), + ) + self.assertEqual( + response_data["events_count"], + Event.objects.filter(deployment__project=project, deployment__isnull=False).count(), + ) + self.assertEqual( + response_data["captures_count"], SourceImage.objects.filter(deployment__project=project).count() + ) + self.assertEqual( + response_data["occurrences_count"], + Occurrence.objects.filter( + project=project, + determination_score__gte=0, + event__isnull=False, + ).count(), + ) + self.assertEqual( + response_data["captures_count"], SourceImage.objects.filter(deployment__project=project).count() + ) + + def test_project_collections(self): + project_id = self.project_ids[1] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/captures/collections/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + expected_project_collection_ids = { + source_image_collection.id + for source_image_collection in SourceImageCollection.objects.filter(project=project) + } + response_source_image_collection_ids = {result.get("id") for result in response_data["results"]} + self.assertEqual(response_source_image_collection_ids, expected_project_collection_ids) + + def test_project_pipelines(self): + project_id = self.project_ids[0] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/ml/pipelines/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + + expected_project_pipeline_ids = {pipeline.id for pipeline in Pipeline.objects.filter(projects=project)} + response_pipeline_ids = {pipeline.get("id") for pipeline in response_data["results"]} + self.assertEqual(response_pipeline_ids, expected_project_pipeline_ids) + + def test_project_storage(self): + project_id = self.project_ids[0] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/storage/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + expected_storage_ids = {storage.id for storage in S3StorageSource.objects.filter(project=project)} + response_storage_ids = {storage.get("id") for storage in response_data["results"]} + self.assertEqual(response_storage_ids, expected_storage_ids) + + def test_project_sites(self): + project_id = self.project_ids[1] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/deployments/sites/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + exepcted_site_ids = {site.id for site in Site.objects.filter(project=project)} + response_site_ids = {site.get("id") for site in response_data["results"]} + self.assertEqual(response_site_ids, exepcted_site_ids) + + def test_project_devices(self): + project_id = self.project_ids[1] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/deployments/devices/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + exepcted_device_ids = {device.id for device in Device.objects.filter(project=project)} + response_device_ids = {device.get("id") for device in response_data["results"]} + self.assertEqual(response_device_ids, exepcted_device_ids) diff --git a/ami/ml/views.py b/ami/ml/views.py index b110373d1..2ee4d8dd0 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -1,4 +1,8 @@ +from django.db.models.query import QuerySet +from drf_spectacular.utils import extend_schema + from ami.main.api.views import DefaultViewSet +from ami.utils.requests import get_active_project, project_id_doc_param from .models.algorithm import Algorithm from .models.pipeline import Pipeline @@ -35,6 +39,18 @@ class PipelineViewSet(DefaultViewSet): "created_at", "updated_at", ] + + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(projects=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + # Don't enable projects filter until we can use the current users # membership to filter the projects. # filterset_fields = ["projects"] diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py index dd9ec761f..6d8164bbf 100644 --- a/ami/tests/fixtures/main.py +++ b/ami/tests/fixtures/main.py @@ -70,7 +70,7 @@ def create_ml_pipeline(project): for algorithm_data in pipeline_data["algorithms"]: algorithm, _ = Algorithm.objects.get_or_create(name=algorithm_data["name"], key=algorithm_data["key"]) pipeline.algorithms.add(algorithm) - + pipeline.projects.add(project) pipeline.save() return pipeline diff --git a/ami/utils/requests.py b/ami/utils/requests.py index 832eef19a..50edb0652 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -1,6 +1,9 @@ from django.forms import FloatField +from drf_spectacular.utils import OpenApiParameter from rest_framework.request import Request +from ami.main.models import Project + def get_active_classification_threshold(request: Request) -> float: # Look for a query param to filter by score @@ -11,3 +14,18 @@ def get_active_classification_threshold(request: Request) -> float: else: classification_threshold = 0 return classification_threshold + + +def get_active_project(request: Request) -> Project | None: + project_id = request.query_params.get("project_id") + if project_id: + return Project.objects.filter(id=project_id).first() + return None + + +project_id_doc_param = OpenApiParameter( + name="project_id", + description="Filter by project ID", + required=False, + type=int, +) diff --git a/ui/src/data-services/utils.ts b/ui/src/data-services/utils.ts index ee04806c7..c9d5c57e9 100644 --- a/ui/src/data-services/utils.ts +++ b/ui/src/data-services/utils.ts @@ -14,7 +14,7 @@ export const getFetchUrl = ({ const queryParams: QueryParams = {} if (params?.projectId) { - queryParams.project = params?.projectId + queryParams.project_id = params?.projectId } if (params?.sort) { const order = params.sort.order === 'asc' ? '' : '-' From 2dadd698384e43329e5f9377c51d97cff48f62eb Mon Sep 17 00:00:00 2001 From: Anna Viklund Date: Fri, 17 Jan 2025 01:41:26 +0100 Subject: [PATCH 2/5] Fix update password form (#670) * fix: update form layout to avoid nested forms * style: update order of fields and tweak box styles * style: tweak layout --- .../user-info-form/user-info-form.module.scss | 1 + .../user-info-form/user-info-form.tsx | 40 +++++++++++-------- .../user-info-form/user-password-field.tsx | 4 +- ui/src/utils/language.ts | 2 +- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss index a73d0a3a8..b7b9a6c3c 100644 --- a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss +++ b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss @@ -4,6 +4,7 @@ .miniForm { border-radius: 8px; border: 1px solid $color-neutral-100; + overflow: hidden; } .miniFormContent { diff --git a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx index d3792fffc..d0f02e0ca 100644 --- a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx +++ b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx @@ -11,6 +11,7 @@ import { useUpdateUserInfo } from 'data-services/hooks/auth/useUpdateUserInfo' import { Button, ButtonTheme } from 'design-system/components/button/button' import { IconType } from 'design-system/components/icon/icon' import { InputContent } from 'design-system/components/input/input' +import { useRef } from 'react' import { useForm } from 'react-hook-form' import { API_MAX_UPLOAD_SIZE } from 'utils/constants' import { STRING, translate } from 'utils/language' @@ -55,6 +56,7 @@ const config: FormConfig = { } export const UserInfoForm = ({ userInfo }: { userInfo: UserInfo }) => { + const formRef = useRef(null) const { control, handleSubmit, @@ -68,29 +70,29 @@ export const UserInfoForm = ({ userInfo }: { userInfo: UserInfo }) => { const errorMessage = useFormError({ error, setFieldError }) return ( -
updateUserInfo(values))}> - {errorMessage && ( - - )} + <> + {errorMessage && ( + + )} - - <> - + updateUserInfo(values))} + className="grid gap-8" + > - - { )} /> - - + +