From 39d5d000d963745df4f46ecd031e5eb323e9868a Mon Sep 17 00:00:00 2001 From: Luka Peschke Date: Fri, 10 Jan 2025 10:35:16 +0100 Subject: [PATCH] fix(server): count null values as well in `count` aggregation [TCTC-9967] (#2315) * fix(server/pandas): count null values as well in count aggregation [TCTC-9967] * fix(server/pypika): count null values as well in count distinct aggregation * fix: enforce ASC NULLS LAST for aggregate step ORDER BY --------- Signed-off-by: Luka Peschke --- server/CHANGELOG.md | 4 ++ .../pandas_executor/steps/aggregate.py | 13 +++-- .../pypika_translator/translators/base.py | 32 +++++++++---- .../fixtures/aggregate/count_nulls.yaml | 47 +++++++++++++++++++ .../aggregate/count_nulls_pypika.yaml | 41 ++++++++++++++++ .../test_base_translator.py | 24 +++++++--- .../test_base_translator_strings.py | 6 +-- 7 files changed, 143 insertions(+), 24 deletions(-) create mode 100644 server/tests/backends/fixtures/aggregate/count_nulls.yaml create mode 100644 server/tests/backends/fixtures/aggregate/count_nulls_pypika.yaml diff --git a/server/CHANGELOG.md b/server/CHANGELOG.md index 0cede43bc9..f51386f700 100644 --- a/server/CHANGELOG.md +++ b/server/CHANGELOG.md @@ -2,6 +2,10 @@ ## Unreleased +### Fixed + +- Pandas & Pypika: the `count` aggregation of the aggregate step now properly counts nulls + ## [0.48.6] - 2024-12-11 ### Fixed diff --git a/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py b/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py index 16f8af10c4..c4372327b9 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/aggregate.py @@ -1,6 +1,6 @@ from typing import Any, Literal -from pandas import DataFrame, concat +from pandas import DataFrame, Series, concat from weaverbird.backends.pandas_executor.types import DomainRetriever, PipelineExecutor from weaverbird.pipeline.steps import AggregateStep @@ -10,15 +10,20 @@ "sum", "min", 
"max", - "count", "count distinct", "first", "last", "count distinct including empty", ] + +def _count(series: Series) -> int: + return series.size + + functions_aliases = { "avg": "mean", + "count": _count, "count distinct": "nunique", "count distinct including empty": len, } @@ -33,8 +38,8 @@ def get_aggregate_fn(agg_function: str) -> Any: def execute_aggregate( step: AggregateStep, df: DataFrame, - domain_retriever: DomainRetriever = None, - execute_pipeline: PipelineExecutor = None, + domain_retriever: DomainRetriever | None = None, + execute_pipeline: PipelineExecutor | None = None, ) -> DataFrame: group_by_columns = step.on diff --git a/server/src/weaverbird/backends/pypika_translator/translators/base.py b/server/src/weaverbird/backends/pypika_translator/translators/base.py index 3330f29786..cb42d269d9 100644 --- a/server/src/weaverbird/backends/pypika_translator/translators/base.py +++ b/server/src/weaverbird/backends/pypika_translator/translators/base.py @@ -2,6 +2,7 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from datetime import UTC, date, datetime +from enum import Enum from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, get_args from uuid import uuid4 @@ -11,7 +12,6 @@ Case, Criterion, Field, - Order, Query, Schema, Table, @@ -223,6 +223,12 @@ def __init__(self, date_part, interval, term, alias=None): super().__init__("DATEADD", LiteralValue(date_part), interval, term, alias=alias) +class Order(str, Enum): + ASC = "ASC" + DESC = "DESC" + ASC_NULLS_LAST = "ASC NULLS LAST" + + class SQLTranslator(ABC): DIALECT: SQLDialect QUERY_CLS: Query @@ -565,7 +571,11 @@ def _build_window_subquery() -> Any: for agg_column_name, new_column_name in zip(aggregation.columns, aggregation.new_columns, strict=True): if new_column_name not in agg_col_names: column_field = prev_step_table[agg_column_name] - new_agg_col = agg_fn(column_field).as_(new_column_name) + # Count("column") ignores NULL values, whereas 
COUNT(*) takes them into account + if agg_fn is functions.Count: + new_agg_col = agg_fn("*").as_(new_column_name) + else: + new_agg_col = agg_fn(column_field).as_(new_column_name) agg_selected.append(new_agg_col) agg_col_names.append(new_column_name) @@ -598,7 +608,7 @@ def _build_window_subquery() -> Any: window_table = Table("window_subquery") all_windows_subquery = _build_window_subquery() agg_query = ( - prev_step_table.select(*agg_selected, *step.on).groupby(*step.on).orderby(*step.on, order=Order.asc) + prev_step_table.select(*agg_selected, *step.on).groupby(*step.on).orderby(*step.on, order=Order.ASC) ).as_("agg_subquery") agg_table = Table("agg_subquery") merged_selected: list[str | Field] = [ @@ -648,7 +658,9 @@ def _build_window_subquery() -> Any: ) ) selected_col_names = [*columns, *agg_col_names] - return StepContext(query.orderby(*step.on) if step.on else query, selected_col_names) + return StepContext( + query.orderby(*step.on, order=Order.ASC_NULLS_LAST) if step.on else query, selected_col_names + ) else: selected_col_names = [ @@ -657,7 +669,7 @@ def _build_window_subquery() -> Any: *(f[1].alias for f in window_selected), ] return StepContext( - merged_query.orderby(*step.on) if step.on else merged_query, + merged_query.orderby(*step.on, order=Order.ASC_NULLS_LAST) if step.on else merged_query, selected_col_names, ) @@ -1625,7 +1637,7 @@ def rank( analytics_fn = analytics.Rank if step.method == "standard" else analytics.DenseRank rank_column = ( (analytics_fn().over(*(Field(group) for group in step.groupby)) if step.groupby else analytics_fn()) - .orderby(col_field, order=Order.desc if step.order == "desc" else Order.asc) + .orderby(col_field, order=Order.DESC if step.order == "desc" else Order.ASC) .as_(new_col_name) ) @@ -1711,7 +1723,7 @@ def sort( query = prev_step_table.select(*columns) for column_sort in step.columns: - query = query.orderby(column_sort.column, order=Order.desc if column_sort.order == "desc" else Order.asc) + query = 
query.orderby(column_sort.column, order=Order.DESC if column_sort.order == "desc" else Order.ASC) return StepContext(query, columns) @@ -1853,7 +1865,7 @@ def top( sub_query = sub_query.select( RowNumber() .over(*groups_fields) - .orderby(rank_on_field, order=Order.desc if step.sort == "desc" else Order.asc) + .orderby(rank_on_field, order=Order.DESC if step.sort == "desc" else Order.ASC) .as_("row_number") ) query: QueryBuilder = ( @@ -1862,7 +1874,7 @@ def top( .where(Field("row_number") <= step.limit) # The order of returned results is not necessarily consistent. This ensures we # always get the results in the same order - .orderby(*(Field(f) for f in step.groups + ["row_number"]), order=Order.asc) + .orderby(*(Field(f) for f in step.groups + ["row_number"]), order=Order.ASC) ) return StepContext(query, columns) @@ -1871,7 +1883,7 @@ def top( query = ( prev_step_table.select(*columns) - .orderby(step.rank_on, order=Order.desc if step.sort == "desc" else Order.asc) + .orderby(step.rank_on, order=Order.DESC if step.sort == "desc" else Order.ASC) .limit(step.limit) ) return StepContext(query, columns) diff --git a/server/tests/backends/fixtures/aggregate/count_nulls.yaml b/server/tests/backends/fixtures/aggregate/count_nulls.yaml new file mode 100644 index 0000000000..ff84a80a02 --- /dev/null +++ b/server/tests/backends/fixtures/aggregate/count_nulls.yaml @@ -0,0 +1,47 @@ +exclude: +- athena_pypika +- bigquery_pypika +- mysql_pypika +- postgres_pypika +- redshift_pypika +- snowflake_pypika +step: + pipeline: + - aggregations: + - aggfunction: count + columns: + - VALUE + newcolumns: + - VALUE_COUNT + keepOriginalGranularity: false + name: aggregate + 'on': + - VALUE +input: + data: + - VALUE: one + - VALUE: two + - VALUE: null + - VALUE: one + - VALUE: null + - VALUE: one + schema: + fields: + - name: VALUE + type: string + pandas_version: 1.4.0 +expected: + data: + - VALUE: one + VALUE_COUNT: 3 + - VALUE: two + VALUE_COUNT: 1 + - VALUE: null + VALUE_COUNT: 2 + 
schema: + fields: + - name: VALUE + type: string + - name: VALUE_COUNT + type: integer + pandas_version: 1.4.0 diff --git a/server/tests/backends/fixtures/aggregate/count_nulls_pypika.yaml b/server/tests/backends/fixtures/aggregate/count_nulls_pypika.yaml new file mode 100644 index 0000000000..55e5116001 --- /dev/null +++ b/server/tests/backends/fixtures/aggregate/count_nulls_pypika.yaml @@ -0,0 +1,41 @@ +exclude: + - mongo + - pandas + - snowflake +step: + pipeline: + - aggregations: + - aggfunction: count + columns: + - nullable_name + newcolumns: + - nullable_name_count + keepOriginalGranularity: false + name: aggregate + 'on': + - nullable_name +expected: + data: + - nullable_name: Ardwen Blonde + nullable_name_count: 1 + - nullable_name: Bellfield Lawless Village IPA + nullable_name_count: 1 + - nullable_name: Brewdog Nanny State Alcoholvrij + nullable_name_count: 1 + - nullable_name: Brugse Zot blonde + nullable_name_count: 1 + - nullable_name: Ninkasi Ploploplop + nullable_name_count: 1 + - nullable_name: Pauwel Kwak + nullable_name_count: 1 + - nullable_name: Weihenstephan Hefe Weizen Alcoholarm + nullable_name_count: 1 + - nullable_name: null + nullable_name_count: 3 + schema: + fields: + - name: nullable_name + type: string + - name: nullable_name_count + type: integer + pandas_version: 1.4.0 diff --git a/server/tests/backends/sql_translator_unit_tests/test_base_translator.py b/server/tests/backends/sql_translator_unit_tests/test_base_translator.py index 1d07e68a1e..695a0c1314 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_base_translator.py +++ b/server/tests/backends/sql_translator_unit_tests/test_base_translator.py @@ -3,13 +3,13 @@ from zoneinfo import ZoneInfo import pytest -from pypika import AliasedQuery, Case, Field, Order, Query, Schema, Table, analytics, functions +from pypika import AliasedQuery, Case, Field, Query, Schema, Table, analytics, functions from pypika.enums import JoinType from pypika.queries import QueryBuilder 
from pypika.terms import LiteralValue, Term, ValueWrapper from pytest_mock import MockFixture -from weaverbird.backends.pypika_translator.translators.base import DataTypeMapping, SQLTranslator +from weaverbird.backends.pypika_translator.translators.base import DataTypeMapping, Order, SQLTranslator from weaverbird.backends.pypika_translator.translators.exceptions import ( ForbiddenSQLStep, UnknownTableColumns, @@ -198,8 +198,12 @@ def test_aggregate(base_translator: BaseTranslator, agg_type: str, default_step_ agg_func = base_translator._get_aggregate_function(agg_type) field = Field(agg_field) + agged = agg_func(field) if agg_func is not functions.Count else functions.Count("*") expected_query = ( - Query.from_(previous_step).groupby(field).orderby(agg_field).select(field, agg_func(field).as_(new_column)) + Query.from_(previous_step) + .groupby(field) + .orderby(agg_field, order=Order.ASC_NULLS_LAST) + .select(field, agged.as_(new_column)) ) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -223,7 +227,8 @@ def test_aggregate_with_original_granularity( agg_func = base_translator._get_aggregate_function(agg_type) field = Field(agg_field) - agg_query = Query.from_(previous_step).groupby(field).select(field, agg_func(field).as_(new_column)) + agged = agg_func(field) if agg_func is not functions.Count else functions.Count("*") + agg_query = Query.from_(previous_step).groupby(field).select(field, agged.as_(new_column)) expected_query = ( Query.from_(previous_step) @@ -231,7 +236,7 @@ def test_aggregate_with_original_granularity( .left_join(agg_query) .on_field(agg_field) .select(*original_select) - .orderby(agg_field) + .orderby(agg_field, order=Order.ASC_NULLS_LAST) ) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -544,7 +549,7 @@ def test_sort(base_translator: BaseTranslator, default_step_kwargs: dict[str, An step = steps.SortStep(columns=columns) ctx = base_translator.sort(step=step, columns=selected_columns, **default_step_kwargs) - 
expected_query = Query.from_(previous_step).select(*selected_columns).orderby(Field("name"), order=Order.asc) + expected_query = Query.from_(previous_step).select(*selected_columns).orderby(Field("name"), order=Order.ASC) assert ctx.selectable.get_sql() == expected_query.get_sql() @@ -623,7 +628,12 @@ def test_uniquegroups(base_translator: BaseTranslator, default_step_kwargs: dict step = steps.UniqueGroupsStep(on=columns) ctx = base_translator.uniquegroups(step=step, columns=selected_columns, **default_step_kwargs) - expected_query = Query.from_(previous_step).select(Field(column)).groupby(Field(column)).orderby(Field(column)) + expected_query = ( + Query.from_(previous_step) + .select(Field(column)) + .groupby(Field(column)) + .orderby(Field(column), order=Order.ASC_NULLS_LAST) + ) assert ctx.selectable.get_sql() == expected_query.get_sql() diff --git a/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py b/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py index bdec2cdab2..671ab8e959 100644 --- a/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py +++ b/server/tests/backends/sql_translator_unit_tests/test_base_translator_strings.py @@ -18,7 +18,7 @@ aggregations=[{"aggfunction": "count", "new_columns": ["beer_count"], "columns": ["name"]}], ), ], - 'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT("name") "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind"', + 'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT(*) "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind" ASC NULLS LAST', ), ( [ @@ -29,7 +29,7 @@ keep_original_granularity=True, ), ], - 'WITH __step_0_dummy__ AS 
(SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "__step_0_dummy__"."price_per_l","__step_0_dummy__"."alcohol_degree","__step_0_dummy__"."name","__step_0_dummy__"."cost","__step_0_dummy__"."beer_kind","__step_0_dummy__"."volume_ml","__step_0_dummy__"."brewing_date","__step_0_dummy__"."nullable_name","sq0"."beer_count" FROM "__step_0_dummy__" LEFT JOIN (WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT("name") "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind") "sq0" ON "__step_0_dummy__"."beer_kind"="sq0"."beer_kind" ORDER BY "__step_0_dummy__"."beer_kind"', + 'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "__step_0_dummy__"."price_per_l","__step_0_dummy__"."alcohol_degree","__step_0_dummy__"."name","__step_0_dummy__"."cost","__step_0_dummy__"."beer_kind","__step_0_dummy__"."volume_ml","__step_0_dummy__"."brewing_date","__step_0_dummy__"."nullable_name","sq0"."beer_count" FROM "__step_0_dummy__" LEFT JOIN (WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT(*) "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind") "sq0" ON "__step_0_dummy__"."beer_kind"="sq0"."beer_kind" ORDER BY "__step_0_dummy__"."beer_kind" ASC NULLS LAST', ), ( [ @@ -47,7 +47,7 @@ ), steps.AbsoluteValueStep(column="avg_price_per_l", new_column="avg_price_per_l_abs"), ], - 'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") ,__step_1_dummy__ AS (SELECT "beer_kind",COUNT("name") "beer_count",AVG("price_per_l") "avg_price_per_l" FROM 
"__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind") SELECT "beer_kind","beer_count","avg_price_per_l",ABS("avg_price_per_l") "avg_price_per_l_abs" FROM "__step_1_dummy__"', + 'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") ,__step_1_dummy__ AS (SELECT "beer_kind",COUNT(*) "beer_count",AVG("price_per_l") "avg_price_per_l" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind" ASC NULLS LAST) SELECT "beer_kind","beer_count","avg_price_per_l",ABS("avg_price_per_l") "avg_price_per_l_abs" FROM "__step_1_dummy__"', ), ( [