From 66f94d64ced0f805d9445fa422982eb392223b46 Mon Sep 17 00:00:00 2001 From: abondar Date: Wed, 24 Apr 2024 19:43:58 +0300 Subject: [PATCH] Fix annotation propagation for non-filter queries (#1590) --- poetry.lock | 2 +- pyproject.toml | 3 +- tests/test_aggregation.py | 78 ++++++++++++++++++++++++++++++++++++++- tortoise/queryset.py | 20 ++++++---- 4 files changed, 93 insertions(+), 10 deletions(-) diff --git a/poetry.lock b/poetry.lock index 20c48b26f..99816c91f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3497,4 +3497,4 @@ psycopg = ["psycopg"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "299e5a492a12bdad51cd7708f9be04a786993b3d3ae3fbfffe06b666e90328f9" +content-hash = "9b882705ce010208b418c97d944bdb75dac3592b7f3acf69f01710cb38de63c2" diff --git a/pyproject.toml b/pyproject.toml index 7d7ace03d..5cfde39ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ aiomysql = { version = "*", optional = true } asyncmy = { version = "^0.2.8", optional = true, allow-prereleases = true } psycopg = { extras = ["pool", "binary"], version = "^3.0.12", optional = true } asyncodbc = { version = "^0.1.1", optional = true } +pydantic = { version = "^2.0,!=2.7.0", optional = true } [tool.poetry.dev-dependencies] # Linter tools @@ -72,7 +73,7 @@ sanic = "*" # Sample integration - Starlette starlette = "*" # Pydantic support -pydantic = "^2.0" +pydantic = "^2.0,!=2.7.0" # FastAPI support fastapi = "^0.100.0" asgi_lifespan = "*" diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index f8e62319d..302f53493 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -1,3 +1,5 @@ +import pytest + from tests.testmodels import Author, Book, Event, MinRelation, Team, Tournament from tortoise.contrib import test from tortoise.contrib.test.condition import In @@ -46,7 +48,8 @@ async def test_aggregation(self): await Event.all().annotate(tournament_test_id=Sum("tournament__id")).first() ) self.assertEqual( - event_with_annotation.tournament_test_id, event_with_annotation.tournament_id + event_with_annotation.tournament_test_id, + event_with_annotation.tournament_id, ) with self.assertRaisesRegex(ConfigurationError, "name__id not resolvable"): @@ -162,3 +165,76 @@ async def test_concat_functions(self): .values("long_info") ) self.assertEqual(ret, [{"long_info": "Physics Book(physics)"}]) + + async def test_count_after_aggregate(self): + author = await Author.create(name="1") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) + + author2 = await Author.create(name="2") + await Book.create(name="F-2", author=author2, rating=3) + await Book.create(name="F-3", author=author2, rating=3) + + author3 = await Author.create(name="3") + await Book.create(name="F-4", author=author3, rating=3) + await Book.create(name="F-5", author=author3, rating=2) + ret = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .filter(average_rating__gte=3) + .count() + ) + + assert ret == 2 + + async def test_exist_after_aggregate(self): + author = await Author.create(name="1") + await Book.create(name="First!", author=author, rating=4) + await Book.create(name="Second!", author=author, rating=3) + await Book.create(name="Third!", author=author, rating=3) + + ret = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .filter(average_rating__gte=3) + .exists() + ) + + assert ret is True + + ret = ( + await Author.all() + .annotate(average_rating=Avg("books__rating")) + .filter(average_rating__gte=4) + .exists() + ) + assert ret is False + + async def test_count_after_aggregate_m2m(self): + tournament = await Tournament.create(name="1") + event1 = await Event.create(name="First!", tournament=tournament) + event2 = await Event.create(name="Second!", tournament=tournament) + event3 = await Event.create(name="Third!", tournament=tournament) + event4 = await Event.create(name="Fourth!", tournament=tournament) + + team1 = await Team.create(name="1") + team2 = await Team.create(name="2") + team3 = await Team.create(name="3") + + await event1.participants.add(team1, team2, team3) + await event2.participants.add(team1, team2) + await event3.participants.add(team1) + await event4.participants.add(team1, team2, team3) + + query = ( + Event.filter(participants__id__in=[team1.id, team2.id, team3.id]) + .annotate(count=Count("event_id")) + .filter(count=3) + .prefetch_related("participants") + ) + result = await query + assert len(result) == 2 + + res = await query.count() + assert res == 2 diff --git a/tortoise/queryset.py b/tortoise/queryset.py index b1957307d..074e7868d 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -21,7 +21,8 @@ ) from pypika import JoinType, Order, Table -from pypika.functions import Cast, Count +from pypika.analytics import Count +from pypika.functions import Cast from pypika.queries import QueryBuilder from pypika.terms import Case, Field, Term, ValueWrapper from typing_extensions import Literal, Protocol @@ -131,7 +132,7 @@ def resolve_filters( :param annotations: Extra annotations to add. :param custom_filters: Pre-resolved filters to be passed through. """ - has_aggregate = self._resolve_annotate() + has_aggregate = self._resolve_annotate(annotations) modifier = QueryModifier() for node in q_objects: @@ -236,13 +237,14 @@ def resolve_ordering( self.query = self.query.orderby(field, order=ordering[1]) - def _resolve_annotate(self) -> bool: - if not self._annotations: + def _resolve_annotate(self, extra_annotations: Dict[str, Any]) -> bool: + if not self._annotations and not extra_annotations: return False table = self.model._meta.basetable + all_annotations = {**self._annotations, **extra_annotations} annotation_info = {} - for key, annotation in self._annotations.items(): + for key, annotation in all_annotations.items(): if isinstance(annotation, Term): annotation_info[key] = {"joins": [], "field": annotation} else: @@ -251,7 +253,8 @@ def _resolve_annotate(self) -> bool: for key, info in annotation_info.items(): for join in info["joins"]: self._join_table_by_field(*join) - self.query._select_other(info["field"].as_(key)) + if key in self._annotations: + self.query._select_other(info["field"].as_(key)) return any(info["field"].is_aggregate for info in annotation_info.values()) @@ -1282,7 +1285,10 @@ def _make_query(self) -> None: annotations=self.annotations, custom_filters=self.custom_filters, ) - self.query._select_other(Count("*")) + count_term = Count("*") + if self.query._groupbys: + count_term = count_term.over() + self.query._select_other(count_term) if self.force_indexes: self.query._force_indexes = []