diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 000000000..5f6dbd990 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,34 @@ +## Summary +Brief description of what this PR does. (tl;dr). + +### List of Changes +* Modified class X +* Added model Y +* Fixed problem Z +* etc. + +### Related Issues +If the PR closes or is related to an issue, reference it here. +For example, "Closes #123", "Fixes #456" or "Relates to #741". + +## Detailed Description +A clear and detailed description of the changes, how they solve/fix the related issues. + +Mention potential side effects or risks associated with the changes, if applicable. + +### How to Test the Changes +Instructions on how to test the changes. Include references to automated and/or manual tests that were created/used to test the changes. + +### Screenshots +If applicable, add screenshots to help explain this PR (ex. Before and after for UI changes). + +## Deployment Notes +Include instructions if this PR requires specific steps for its deployment (database migrations, config changes, etc.) + +## Checklist + +- [ ] I have tested these changes appropriately. +- [ ] I have added and/or modified relevant tests. +- [ ] I updated relevant documentation or comments. +- [ ] I have verified that this PR follows the project's coding standards. +- [ ] Any dependent changes have already been merged to main. 
diff --git a/README.md b/README.md index 10783e2a8..a3e3b12da 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,12 @@ docker compose run --rm django python manage.py test -k pattern docker compose run --rm django python manage.py test -k pattern --failfast --pdb ``` +##### Speed up development of tests by reusing the db between test runs + +```bash +docker compose run --rm django python manage.py test --keepdb +``` + ##### Run management scripts ```bash diff --git a/ami/jobs/models.py b/ami/jobs/models.py index 33d54f726..3bd192106 100644 --- a/ami/jobs/models.py +++ b/ami/jobs/models.py @@ -285,6 +285,11 @@ class JobType: name: str key: str + # @TODO Consider adding custom vocabulary for job types to be used in the UI + # verb: str = "Sync" + # present_participle: str = "syncing" + # past_participle: str = "synced" + @classmethod def run(cls, job: "Job"): """ diff --git a/ami/jobs/views.py b/ami/jobs/views.py index f8f2e34a2..e783ff9a5 100644 --- a/ami/jobs/views.py +++ b/ami/jobs/views.py @@ -3,11 +3,13 @@ from django.db.models.query import QuerySet from django.forms import IntegerField from django.utils import timezone +from drf_spectacular.utils import extend_schema from rest_framework.decorators import action from rest_framework.response import Response from ami.main.api.views import DefaultViewSet from ami.utils.fields import url_boolean_param +from ami.utils.requests import get_active_project, project_id_doc_param from .models import Job, JobState, MLJob from .serializers import JobListSerializer, JobSerializer @@ -35,7 +37,6 @@ class JobViewSet(DefaultViewSet): """ queryset = Job.objects.select_related( - "project", "deployment", "pipeline", "source_image_collection", @@ -128,7 +129,9 @@ def perform_create(self, serializer): def get_queryset(self) -> QuerySet: jobs = super().get_queryset() - + project = get_active_project(self.request) + if project: + jobs = jobs.filter(project=project) cutoff_hours = IntegerField(required=False, min_value=0).clean( 
self.request.query_params.get("cutoff_hours", Job.FAILED_CUTOFF_HOURS) ) @@ -138,3 +141,7 @@ def get_queryset(self) -> QuerySet: status=JobState.failed_states(), updated_at__lt=cutoff_datetime, ) + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) diff --git a/ami/main/api/serializers.py b/ami/main/api/serializers.py index eb7164489..01b06ad44 100644 --- a/ami/main/api/serializers.py +++ b/ami/main/api/serializers.py @@ -127,12 +127,40 @@ class Meta: ] +class JobTypeSerializer(serializers.Serializer): + """ + Serializer for the JobType json field in the Job model. + + This is duplicated from ami.jobs.serializers to avoid circular imports. + but it is extremely simple. + """ + + name = serializers.CharField(read_only=True) + key = serializers.SlugField(read_only=True) + + +class JobStatusSerializer(DefaultSerializer): + job_type = JobTypeSerializer(read_only=True) + + class Meta: + model = Job + fields = [ + "id", + "details", + "status", + "job_type", + "created_at", + "updated_at", + ] + + class DeploymentListSerializer(DefaultSerializer): events = serializers.SerializerMethodField() occurrences = serializers.SerializerMethodField() project = ProjectNestedSerializer(read_only=True) device = DeviceNestedSerializer(read_only=True) research_site = SiteNestedSerializer(read_only=True) + jobs = JobStatusSerializer(many=True, read_only=True) class Meta: model = Deployment @@ -156,6 +184,7 @@ class Meta: "last_date", "device", "research_site", + "jobs", ] def get_events(self, obj): @@ -503,7 +532,7 @@ def get_occurrence_images(self, obj): # request = self.context.get("request") # project_id = request.query_params.get("project") if request else None - project_id = self.context["request"].query_params["project"] + project_id = self.context["request"].query_params["project_id"] classification_threshold = get_active_classification_threshold(self.context["request"]) return 
obj.occurrence_images( @@ -849,18 +878,6 @@ class Meta: ] -class JobStatusSerializer(DefaultSerializer): - class Meta: - model = Job - fields = [ - "id", - "details", - "status", - "created_at", - "updated_at", - ] - - class SourceImageCollectionNestedSerializer(DefaultSerializer): class Meta: model = SourceImageCollection diff --git a/ami/main/api/views.py b/ami/main/api/views.py index e8e76e537..7ca73bb30 100644 --- a/ami/main/api/views.py +++ b/ami/main/api/views.py @@ -10,6 +10,7 @@ from django.forms import BooleanField, CharField, IntegerField from django.utils import timezone from django_filters.rest_framework import DjangoFilterBackend +from drf_spectacular.utils import extend_schema from rest_framework import exceptions as api_exceptions from rest_framework import filters, serializers, status, viewsets from rest_framework.decorators import action @@ -24,7 +25,7 @@ from ami.base.pagination import LimitOffsetPaginationWithPermissions from ami.base.permissions import IsActiveStaffOrReadOnly from ami.base.serializers import FilterParamsSerializer, SingleParamSerializer -from ami.utils.requests import get_active_classification_threshold +from ami.utils.requests import get_active_classification_threshold, get_active_project, project_id_doc_param from ami.utils.storages import ConnectionTestResult from ..models import ( @@ -137,7 +138,6 @@ class DeploymentViewSet(DefaultViewSet): """ queryset = Deployment.objects.select_related("project", "device", "research_site") - filterset_fields = ["project"] ordering_fields = [ "created_at", "updated_at", @@ -160,7 +160,9 @@ def get_serializer_class(self): def get_queryset(self) -> QuerySet: qs = super().get_queryset() - + project = get_active_project(self.request) + if project: + qs = qs.filter(project=project) num_example_captures = 10 if self.action == "retrieve": qs = qs.prefetch_related( @@ -204,6 +206,10 @@ def sync(self, _request, pk=None) -> Response: else: raise api_exceptions.ValidationError(detail="Deployment must 
have a data source to sync captures from") + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class EventViewSet(DefaultViewSet): """ @@ -212,7 +218,7 @@ class EventViewSet(DefaultViewSet): queryset = Event.objects.all() serializer_class = EventSerializer - filterset_fields = ["deployment", "project"] + filterset_fields = ["deployment"] ordering_fields = [ "created_at", "updated_at", @@ -237,6 +243,9 @@ def get_serializer_class(self): def get_queryset(self) -> QuerySet: qs: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + qs = qs.filter(project=project) qs = qs.filter(deployment__isnull=False) qs = qs.annotate( duration=models.F("end") - models.F("start"), @@ -363,6 +372,10 @@ def timeline(self, request, pk=None): ) return Response(serializer.data) + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class SourceImageViewSet(DefaultViewSet): """ @@ -549,7 +562,7 @@ class SourceImageCollectionViewSet(DefaultViewSet): ) serializer_class = SourceImageCollectionSerializer - filterset_fields = ["project", "method"] + filterset_fields = ["method"] ordering_fields = [ "created_at", "updated_at", @@ -562,11 +575,14 @@ class SourceImageCollectionViewSet(DefaultViewSet): def get_queryset(self) -> QuerySet: classification_threshold = get_active_classification_threshold(self.request) - queryset = ( - super() - .get_queryset() - .with_occurrences_count(classification_threshold=classification_threshold) # type: ignore - .with_taxa_count(classification_threshold=classification_threshold) + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(project=project) + queryset = query_set.with_occurrences_count( + classification_threshold=classification_threshold + 
).with_taxa_count( # type: ignore + classification_threshold=classification_threshold ) return queryset @@ -647,6 +663,10 @@ def remove(self, request, pk=None): } ) + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class SourceImageUploadViewSet(DefaultViewSet): """ @@ -903,7 +923,6 @@ class OccurrenceViewSet(DefaultViewSet): filterset_fields = [ "event", "deployment", - "project", "determination__rank", "detections__source_image", ] @@ -933,7 +952,10 @@ def get_serializer_class(self): return OccurrenceSerializer def get_queryset(self) -> QuerySet: + project = get_active_project(self.request) qs = super().get_queryset() + if project: + qs = qs.filter(project=project) qs = qs.select_related( "determination", "deployment", @@ -961,6 +983,10 @@ def get_queryset(self) -> QuerySet: return qs + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class TaxonViewSet(DefaultViewSet): """ @@ -1042,23 +1068,22 @@ def filter_taxa_by_observed(self, queryset: QuerySet) -> tuple[QuerySet, bool]: """ occurrence_id = self.request.query_params.get("occurrence") - project_id = self.request.query_params.get("project") or self.request.query_params.get("occurrences__project") + project = get_active_project(self.request) deployment_id = self.request.query_params.get("deployment") or self.request.query_params.get( "occurrences__deployment" ) event_id = self.request.query_params.get("event") or self.request.query_params.get("occurrences__event") collection_id = self.request.query_params.get("collection") - filter_active = any([occurrence_id, project_id, deployment_id, event_id, collection_id]) + filter_active = any([occurrence_id, project, deployment_id, event_id, collection_id]) - if not project_id: + if not project: # Raise a 400 if no project is specified raise api_exceptions.ValidationError(detail="A 
project must be specified") queryset = super().get_queryset() try: - if project_id: - project = Project.objects.get(id=project_id) + if project: queryset = queryset.filter(occurrences__project=project) if occurrence_id: occurrence = Occurrence.objects.get(id=occurrence_id) @@ -1180,6 +1205,10 @@ def get_queryset(self) -> QuerySet: return qs + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + # def retrieve(self, request: Request, *args, **kwargs) -> Response: # """ # Override the serializer to include the recursive occurrences count @@ -1207,16 +1236,15 @@ class ClassificationViewSet(DefaultViewSet): class SummaryView(GenericAPIView): permission_classes = [IsActiveStaffOrReadOnly] - filterset_fields = ["project"] + @extend_schema(parameters=[project_id_doc_param]) def get(self, request): """ Return counts of all models. """ - project_id = request.query_params.get("project") + project = get_active_project(request) confidence_threshold = get_active_classification_threshold(request) - if project_id: - project = Project.objects.get(id=project_id) + if project: data = { "projects_count": Project.objects.count(), # @TODO filter by current user, here and everywhere! 
"deployments_count": Deployment.objects.filter(project=project).count(), @@ -1358,13 +1386,24 @@ class SiteViewSet(DefaultViewSet): queryset = Site.objects.all() serializer_class = SiteSerializer - filterset_fields = ["project", "deployments"] + filterset_fields = ["deployments"] ordering_fields = [ "created_at", "updated_at", "name", ] + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(project=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class DeviceViewSet(DefaultViewSet): """ @@ -1373,13 +1412,24 @@ class DeviceViewSet(DefaultViewSet): queryset = Device.objects.all() serializer_class = DeviceSerializer - filterset_fields = ["project", "deployments"] + filterset_fields = ["deployments"] ordering_fields = [ "created_at", "updated_at", "name", ] + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(project=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + class StorageSourceConnectionTestSerializer(serializers.Serializer): subdir = serializers.CharField(required=False, allow_null=True) @@ -1393,13 +1443,23 @@ class StorageSourceViewSet(DefaultViewSet): queryset = S3StorageSource.objects.all() serializer_class = StorageSourceSerializer - filterset_fields = ["project", "deployments"] + filterset_fields = ["deployments"] ordering_fields = [ "created_at", "updated_at", "name", ] + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + query_set = query_set.filter(project=project) + return query_set + + 
@extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + @action(detail=True, methods=["post"], name="test", serializer_class=StorageSourceConnectionTestSerializer) def test(self, request: Request, pk=None) -> Response: """ diff --git a/ami/main/models.py b/ami/main/models.py index 2e83890a6..61ace0402 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -115,6 +115,7 @@ class Project(BaseModel): devices: models.QuerySet["Device"] sites: models.QuerySet["Site"] + jobs: models.QuerySet["Job"] def deployments_count(self) -> int: return self.deployments.count() @@ -348,6 +349,7 @@ class Deployment(BaseModel): events: models.QuerySet["Event"] captures: models.QuerySet["SourceImage"] occurrences: models.QuerySet["Occurrence"] + jobs: models.QuerySet["Job"] objects = DeploymentManager() @@ -1199,6 +1201,7 @@ class SourceImage(BaseModel): detections: models.QuerySet["Detection"] collections: models.QuerySet["SourceImageCollection"] + jobs: models.QuerySet["Job"] objects = SourceImageManager() @@ -2745,6 +2748,8 @@ class SourceImageCollection(BaseModel): objects = SourceImageCollectionManager() + jobs: models.QuerySet["Job"] + def source_images_count(self) -> int | None: # This should always be pre-populated using queryset annotations # return self.images.count() diff --git a/ami/main/tests.py b/ami/main/tests.py index d422db060..93f5a87da 100644 --- a/ami/main/tests.py +++ b/ami/main/tests.py @@ -1,12 +1,26 @@ import datetime import logging -from django.db import connection +from django.db import connection, models from django.test import TestCase +from rest_framework import status from rest_framework.test import APIRequestFactory, APITestCase from rich import print -from ami.main.models import Event, Occurrence, Project, Taxon, TaxonRank, group_images_into_events +from ami.main.models import ( + Device, + Event, + Occurrence, + Project, + S3StorageSource, + Site, + SourceImage, 
+ SourceImageCollection, + Taxon, + TaxonRank, + group_images_into_events, +) +from ami.ml.models.pipeline import Pipeline from ami.tests.fixtures.main import create_captures, create_occurrences, create_taxa, setup_test_project from ami.users.models import User @@ -535,7 +549,7 @@ def setUp(self) -> None: def test_occurrences_for_project(self): # Test that occurrences are specific to each project for project in [self.project_one, self.project_two]: - response = self.client.get(f"/api/v2/occurrences/?project={project.pk}") + response = self.client.get(f"/api/v2/occurrences/?project_id={project.pk}") self.assertEqual(response.status_code, 200) self.assertEqual(response.json()["count"], Occurrence.objects.filter(project=project).count()) @@ -578,7 +592,7 @@ def _test_taxa_for_project(self, project: Project): """ from ami.main.models import Taxon - response = self.client.get(f"/api/v2/taxa/?project={project.pk}") + response = self.client.get(f"/api/v2/taxa/?project_id={project.pk}") self.assertEqual(response.status_code, 200) project_occurred_taxa = Taxon.objects.filter(occurrences__project=project).distinct() # project_any_taxa = Taxon.objects.filter(projects=project) @@ -754,3 +768,117 @@ def test_update_subdir(self): self.other_subdir: self.images_per_dir, } self.assertDictEqual(dict(counts), expected_counts) + + +class TestProjectSettingsFiltering(APITestCase): + """Test Project Settings filter by project_id""" + + def setUp(self) -> None: + for _ in range(3): + project, deployment = setup_test_project(reuse=False) + create_taxa(project=project) + create_captures(deployment=deployment) + group_images_into_events(deployment=deployment) + create_occurrences(deployment=deployment, num=5) + self.project_ids = [project.id for project in Project.objects.all()] + + self.user = User.objects.create_user( # type: ignore + email="testuser@insectai.org", + is_staff=True, + ) + self.factory = APIRequestFactory() + self.client.force_authenticate(user=self.user) + return 
super().setUp() + + def test_project_summary(self): + project_id = self.project_ids[1] + endpoint_url = f"/api/v2/status/summary/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + self.assertEqual(response.status_code, status.HTTP_200_OK) + project = Project.objects.get(pk=project_id) + + self.assertEqual(response_data["deployments_count"], project.deployments_count()) + self.assertEqual( + response_data["taxa_count"], + Taxon.objects.annotate(occurrences_count=models.Count("occurrences")) + .filter( + occurrences_count__gt=0, + occurrences__determination_score__gte=0, + occurrences__project=project, + ) + .distinct() + .count(), + ) + self.assertEqual( + response_data["events_count"], + Event.objects.filter(deployment__project=project, deployment__isnull=False).count(), + ) + self.assertEqual( + response_data["captures_count"], SourceImage.objects.filter(deployment__project=project).count() + ) + self.assertEqual( + response_data["occurrences_count"], + Occurrence.objects.filter( + project=project, + determination_score__gte=0, + event__isnull=False, + ).count(), + ) + self.assertEqual( + response_data["captures_count"], SourceImage.objects.filter(deployment__project=project).count() + ) + + def test_project_collections(self): + project_id = self.project_ids[1] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/captures/collections/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + expected_project_collection_ids = { + source_image_collection.id + for source_image_collection in SourceImageCollection.objects.filter(project=project) + } + response_source_image_collection_ids = {result.get("id") for result in response_data["results"]} + self.assertEqual(response_source_image_collection_ids, expected_project_collection_ids) + + def test_project_pipelines(self): + project_id = self.project_ids[0] + project = 
Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/ml/pipelines/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + + expected_project_pipeline_ids = {pipeline.id for pipeline in Pipeline.objects.filter(projects=project)} + response_pipeline_ids = {pipeline.get("id") for pipeline in response_data["results"]} + self.assertEqual(response_pipeline_ids, expected_project_pipeline_ids) + + def test_project_storage(self): + project_id = self.project_ids[0] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/storage/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + expected_storage_ids = {storage.id for storage in S3StorageSource.objects.filter(project=project)} + response_storage_ids = {storage.get("id") for storage in response_data["results"]} + self.assertEqual(response_storage_ids, expected_storage_ids) + + def test_project_sites(self): + project_id = self.project_ids[1] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/deployments/sites/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + expected_site_ids = {site.id for site in Site.objects.filter(project=project)} + response_site_ids = {site.get("id") for site in response_data["results"]} + self.assertEqual(response_site_ids, expected_site_ids) + + def test_project_devices(self): + project_id = self.project_ids[1] + project = Project.objects.get(pk=project_id) + endpoint_url = f"/api/v2/deployments/devices/?project_id={project_id}" + response = self.client.get(endpoint_url) + response_data = response.json() + expected_device_ids = {device.id for device in Device.objects.filter(project=project)} + response_device_ids = {device.get("id") for device in response_data["results"]} + self.assertEqual(response_device_ids, expected_device_ids) diff --git 
a/ami/ml/migrations/0016_alter_processingservice_options.py b/ami/ml/migrations/0016_alter_processingservice_options.py new file mode 100644 index 000000000..31234393d --- /dev/null +++ b/ami/ml/migrations/0016_alter_processingservice_options.py @@ -0,0 +1,16 @@ +# Generated by Django 4.2.10 on 2025-01-16 20:23 + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("ml", "0015_processingservice_delete_backend"), + ] + + operations = [ + migrations.AlterModelOptions( + name="processingservice", + options={"verbose_name": "Processing Service", "verbose_name_plural": "Processing Services"}, + ), + ] diff --git a/ami/ml/views.py b/ami/ml/views.py index 2a6745d71..1f3cd039b 100644 --- a/ami/ml/views.py +++ b/ami/ml/views.py @@ -1,6 +1,8 @@ import logging +from django.db.models.query import QuerySet from django.utils.text import slugify +from drf_spectacular.utils import extend_schema from rest_framework import status from rest_framework.decorators import action from rest_framework.request import Request @@ -8,6 +10,7 @@ from ami.main.api.views import DefaultViewSet from ami.main.models import SourceImage +from ami.utils.requests import get_active_project, project_id_doc_param from .models.algorithm import Algorithm from .models.pipeline import Pipeline @@ -47,6 +50,18 @@ class PipelineViewSet(DefaultViewSet): "created_at", "updated_at", ] + + def get_queryset(self) -> QuerySet: + query_set: QuerySet = super().get_queryset() + project = get_active_project(self.request) + if project: + query_set = query_set.filter(projects=project) + return query_set + + @extend_schema(parameters=[project_id_doc_param]) + def list(self, request, *args, **kwargs): + return super().list(request, *args, **kwargs) + # Don't enable projects filter until we can use the current users # membership to filter the projects. 
# filterset_fields = ["projects"] @@ -60,6 +75,8 @@ def test_process(self, request: Request, pk=None) -> Response: random_image = ( SourceImage.objects.all().order_by("?").first() ) # TODO: Filter images by projects user has access to + if not random_image: + return Response({"error": "No image found to process."}, status=status.HTTP_404_NOT_FOUND) results = pipeline.process_images(images=[random_image], job_id=None) return Response(results.dict()) diff --git a/ami/tests/fixtures/main.py b/ami/tests/fixtures/main.py index fb70a7652..54035c89d 100644 --- a/ami/tests/fixtures/main.py +++ b/ami/tests/fixtures/main.py @@ -6,6 +6,7 @@ import uuid from django.db import transaction +from django.utils import timezone from ami.main.models import ( Deployment, @@ -70,23 +71,44 @@ def create_processing_service(project): return processing_service +def create_deployment( + project: Project, + data_source, + name="Test Deployment", +) -> Deployment: + """ + Create a test deployment with a data source for source images. 
+ """ + deployment, _ = Deployment.objects.get_or_create( + project=project, + name=name, + defaults=dict( + description=f"Created at {timezone.now()}", + data_source=data_source, + data_source_subdir="/", + data_source_regex=".*\\.jpg", + latitude=45.0, + longitude=-123.0, + research_site=project.sites.first(), + device=project.devices.first(), + ), + ) + return deployment + + def setup_test_project(reuse=True) -> tuple[Project, Deployment]: - if reuse: - short_id = "001" - project, _ = Project.objects.get_or_create(name=f"Test Project {short_id}") - data_source = create_storage_source(project, "Test Data Source") - deployment, _ = Deployment.objects.get_or_create( - project=project, name="Test Deployment", defaults=dict(data_source=data_source) - ) - create_processing_service(project) - else: + project = Project.objects.filter(name__startswith="Test Project").first() + + if not project or not reuse: short_id = uuid.uuid4().hex[:8] project = Project.objects.create(name=f"Test Project {short_id}") data_source = create_storage_source(project, f"Test Data Source {short_id}") - deployment = Deployment.objects.create( - project=project, name=f"Test Deployment {short_id}", data_source=data_source - ) + deployment = create_deployment(project, data_source, f"Test Deployment {short_id}") create_processing_service(project) + else: + deployment = Deployment.objects.filter(project=project).first() + assert deployment, "No deployment found for existing project. Create a new project instead." 
+ return project, deployment diff --git a/ami/utils/requests.py b/ami/utils/requests.py index 832eef19a..50edb0652 100644 --- a/ami/utils/requests.py +++ b/ami/utils/requests.py @@ -1,6 +1,9 @@ from django.forms import FloatField +from drf_spectacular.utils import OpenApiParameter from rest_framework.request import Request +from ami.main.models import Project + def get_active_classification_threshold(request: Request) -> float: # Look for a query param to filter by score @@ -11,3 +14,18 @@ def get_active_classification_threshold(request: Request) -> float: else: classification_threshold = 0 return classification_threshold + + +def get_active_project(request: Request) -> Project | None: + project_id = request.query_params.get("project_id") + if project_id: + return Project.objects.filter(id=project_id).first() + return None + + +project_id_doc_param = OpenApiParameter( + name="project_id", + description="Filter by project ID", + required=False, + type=int, +) diff --git a/ui/src/components/error-state/error-state.tsx b/ui/src/components/error-state/error-state.tsx index f5b39b00b..44b078f83 100644 --- a/ui/src/components/error-state/error-state.tsx +++ b/ui/src/components/error-state/error-state.tsx @@ -10,7 +10,8 @@ export const ErrorState = ({ error }: ErrorStateProps) => { const data = error?.response?.data const description = useMemo(() => { - const entries = data ? Object.entries(data) : undefined + const entries = + data && typeof data === 'object' ? 
Object.entries(data) : undefined if (entries?.length) { const [key, value] = entries[0] diff --git a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss index a73d0a3a8..b7b9a6c3c 100644 --- a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss +++ b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.module.scss @@ -4,6 +4,7 @@ .miniForm { border-radius: 8px; border: 1px solid $color-neutral-100; + overflow: hidden; } .miniFormContent { diff --git a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx index d3792fffc..d0f02e0ca 100644 --- a/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx +++ b/ui/src/components/header/user-info-dialog/user-info-form/user-info-form.tsx @@ -11,6 +11,7 @@ import { useUpdateUserInfo } from 'data-services/hooks/auth/useUpdateUserInfo' import { Button, ButtonTheme } from 'design-system/components/button/button' import { IconType } from 'design-system/components/icon/icon' import { InputContent } from 'design-system/components/input/input' +import { useRef } from 'react' import { useForm } from 'react-hook-form' import { API_MAX_UPLOAD_SIZE } from 'utils/constants' import { STRING, translate } from 'utils/language' @@ -55,6 +56,7 @@ const config: FormConfig = { } export const UserInfoForm = ({ userInfo }: { userInfo: UserInfo }) => { + const formRef = useRef(null) const { control, handleSubmit, @@ -68,29 +70,29 @@ export const UserInfoForm = ({ userInfo }: { userInfo: UserInfo }) => { const errorMessage = useFormError({ error, setFieldError }) return ( -
updateUserInfo(values))}> - {errorMessage && ( - - )} + <> + {errorMessage && ( + + )} - - <> - + updateUserInfo(values))} + className="grid gap-8" + > - - { )} /> - - + +