Skip to content

Commit

Permalink
fix(server): count null values as well in count aggregation [TCTC-9967] (#2315)
Browse files Browse the repository at this point in the history

* fix(server/pandas): count null values as well in count aggregation [TCTC-9967]
* fix(server/pypika): count null values as well in count distinct aggregation
* fix: enforce ASC NULLS LAST for aggregate step ORDER BY

---------

Signed-off-by: Luka Peschke <[email protected]>
  • Loading branch information
lukapeschke authored Jan 10, 2025
1 parent 031f919 commit 39d5d00
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 24 deletions.
4 changes: 4 additions & 0 deletions server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

### Fixed

- Pandas & Pypika: the `count` aggregation of the aggregate step now properly counts nulls

## [0.48.6] - 2024-12-11

### Fixed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pandas import DataFrame, concat
from pandas import DataFrame, Series, concat

from weaverbird.backends.pandas_executor.types import DomainRetriever, PipelineExecutor
from weaverbird.pipeline.steps import AggregateStep
Expand All @@ -10,15 +10,20 @@
"sum",
"min",
"max",
"count",
"count distinct",
"first",
"last",
"count distinct including empty",
]


def _count(series: Series) -> int:
return series.size


# Maps weaverbird aggregation names to the pandas aggregation to apply.
# Values are either a string (looked up as a pandas GroupBy method name)
# or a callable applied to each grouped Series.
functions_aliases = {
    "avg": "mean",
    # NOTE: not pandas' "count" (which skips nulls) — _count includes nulls
    "count": _count,
    # nunique drops nulls, matching "count distinct" semantics
    "count distinct": "nunique",
    # len counts every element, empties and nulls included
    "count distinct including empty": len,
}
Expand All @@ -33,8 +38,8 @@ def get_aggregate_fn(agg_function: str) -> Any:
def execute_aggregate(
step: AggregateStep,
df: DataFrame,
domain_retriever: DomainRetriever = None,
execute_pipeline: PipelineExecutor = None,
domain_retriever: DomainRetriever | None = None,
execute_pipeline: PipelineExecutor | None = None,
) -> DataFrame:
group_by_columns = step.on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from datetime import UTC, date, datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, cast, get_args
from uuid import uuid4

Expand All @@ -11,7 +12,6 @@
Case,
Criterion,
Field,
Order,
Query,
Schema,
Table,
Expand Down Expand Up @@ -223,6 +223,12 @@ def __init__(self, date_part, interval, term, alias=None):
super().__init__("DATEADD", LiteralValue(date_part), interval, term, alias=alias)


class Order(str, Enum):
    """ORDER BY direction keywords emitted by the SQL translator.

    Replaces pypika's ``Order`` enum in order to add ``ASC_NULLS_LAST``,
    which forces a deterministic placement of NULL group keys across
    dialects (some databases sort NULLs first by default, others last).
    """

    ASC = "ASC"
    DESC = "DESC"
    # presumably only used for aggregate/uniquegroups ordering — TODO confirm
    ASC_NULLS_LAST = "ASC NULLS LAST"


class SQLTranslator(ABC):
DIALECT: SQLDialect
QUERY_CLS: Query
Expand Down Expand Up @@ -565,7 +571,11 @@ def _build_window_subquery() -> Any:
for agg_column_name, new_column_name in zip(aggregation.columns, aggregation.new_columns, strict=True):
if new_column_name not in agg_col_names:
column_field = prev_step_table[agg_column_name]
new_agg_col = agg_fn(column_field).as_(new_column_name)
# Count("column") ignores NULL values, whereas COUNT(*) takes them into account
if agg_fn is functions.Count:
new_agg_col = agg_fn("*").as_(new_column_name)
else:
new_agg_col = agg_fn(column_field).as_(new_column_name)
agg_selected.append(new_agg_col)
agg_col_names.append(new_column_name)

Expand Down Expand Up @@ -598,7 +608,7 @@ def _build_window_subquery() -> Any:
window_table = Table("window_subquery")
all_windows_subquery = _build_window_subquery()
agg_query = (
prev_step_table.select(*agg_selected, *step.on).groupby(*step.on).orderby(*step.on, order=Order.asc)
prev_step_table.select(*agg_selected, *step.on).groupby(*step.on).orderby(*step.on, order=Order.ASC)
).as_("agg_subquery")
agg_table = Table("agg_subquery")
merged_selected: list[str | Field] = [
Expand Down Expand Up @@ -648,7 +658,9 @@ def _build_window_subquery() -> Any:
)
)
selected_col_names = [*columns, *agg_col_names]
return StepContext(query.orderby(*step.on) if step.on else query, selected_col_names)
return StepContext(
query.orderby(*step.on, order=Order.ASC_NULLS_LAST) if step.on else query, selected_col_names
)

else:
selected_col_names = [
Expand All @@ -657,7 +669,7 @@ def _build_window_subquery() -> Any:
*(f[1].alias for f in window_selected),
]
return StepContext(
merged_query.orderby(*step.on) if step.on else merged_query,
merged_query.orderby(*step.on, order=Order.ASC_NULLS_LAST) if step.on else merged_query,
selected_col_names,
)

Expand Down Expand Up @@ -1625,7 +1637,7 @@ def rank(
analytics_fn = analytics.Rank if step.method == "standard" else analytics.DenseRank
rank_column = (
(analytics_fn().over(*(Field(group) for group in step.groupby)) if step.groupby else analytics_fn())
.orderby(col_field, order=Order.desc if step.order == "desc" else Order.asc)
.orderby(col_field, order=Order.DESC if step.order == "desc" else Order.ASC)
.as_(new_col_name)
)

Expand Down Expand Up @@ -1711,7 +1723,7 @@ def sort(
query = prev_step_table.select(*columns)

for column_sort in step.columns:
query = query.orderby(column_sort.column, order=Order.desc if column_sort.order == "desc" else Order.asc)
query = query.orderby(column_sort.column, order=Order.DESC if column_sort.order == "desc" else Order.ASC)

return StepContext(query, columns)

Expand Down Expand Up @@ -1853,7 +1865,7 @@ def top(
sub_query = sub_query.select(
RowNumber()
.over(*groups_fields)
.orderby(rank_on_field, order=Order.desc if step.sort == "desc" else Order.asc)
.orderby(rank_on_field, order=Order.DESC if step.sort == "desc" else Order.ASC)
.as_("row_number")
)
query: QueryBuilder = (
Expand All @@ -1862,7 +1874,7 @@ def top(
.where(Field("row_number") <= step.limit)
# The order of returned results is not necessarily consistent. This ensures we
# always get the results in the same order
.orderby(*(Field(f) for f in step.groups + ["row_number"]), order=Order.asc)
.orderby(*(Field(f) for f in step.groups + ["row_number"]), order=Order.ASC)
)
return StepContext(query, columns)

Expand All @@ -1871,7 +1883,7 @@ def top(

query = (
prev_step_table.select(*columns)
.orderby(step.rank_on, order=Order.desc if step.sort == "desc" else Order.asc)
.orderby(step.rank_on, order=Order.DESC if step.sort == "desc" else Order.ASC)
.limit(step.limit)
)
return StepContext(query, columns)
Expand Down
47 changes: 47 additions & 0 deletions server/tests/backends/fixtures/aggregate/count_nulls.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
exclude:
- athena_pypika
- bigquery_pypika
- mysql_pypika
- postgres_pypika
- redshift_pypika
- snowflake_pypika
step:
pipeline:
- aggregations:
- aggfunction: count
columns:
- VALUE
newcolumns:
- VALUE_COUNT
keepOriginalGranularity: false
name: aggregate
'on':
- VALUE
input:
data:
- VALUE: one
- VALUE: two
- VALUE: null
- VALUE: one
- VALUE: null
- VALUE: one
schema:
fields:
- name: VALUE
type: string
pandas_version: 1.4.0
expected:
data:
- VALUE: one
VALUE_COUNT: 3
- VALUE: two
VALUE_COUNT: 1
- VALUE: null
VALUE_COUNT: 2
schema:
fields:
- name: VALUE
type: string
- name: VALUE_COUNT
type: integer
pandas_version: 1.4.0
41 changes: 41 additions & 0 deletions server/tests/backends/fixtures/aggregate/count_nulls_pypika.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
exclude:
- mongo
- pandas
- snowflake
step:
pipeline:
- aggregations:
- aggfunction: count
columns:
- nullable_name
newcolumns:
- nullable_name_count
keepOriginalGranularity: false
name: aggregate
'on':
- nullable_name
expected:
data:
- nullable_name: Ardwen Blonde
nullable_name_count: 1
- nullable_name: Bellfield Lawless Village IPA
nullable_name_count: 1
- nullable_name: Brewdog Nanny State Alcoholvrij
nullable_name_count: 1
- nullable_name: Brugse Zot blonde
nullable_name_count: 1
- nullable_name: Ninkasi Ploploplop
nullable_name_count: 1
- nullable_name: Pauwel Kwak
nullable_name_count: 1
- nullable_name: Weihenstephan Hefe Weizen Alcoholarm
nullable_name_count: 1
- nullable_name: null
nullable_name_count: 3
schema:
fields:
- name: nullable_name
type: string
- name: nullable_name_count
type: integer
pandas_version: 1.4.0
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from zoneinfo import ZoneInfo

import pytest
from pypika import AliasedQuery, Case, Field, Order, Query, Schema, Table, analytics, functions
from pypika import AliasedQuery, Case, Field, Query, Schema, Table, analytics, functions
from pypika.enums import JoinType
from pypika.queries import QueryBuilder
from pypika.terms import LiteralValue, Term, ValueWrapper
from pytest_mock import MockFixture

from weaverbird.backends.pypika_translator.translators.base import DataTypeMapping, SQLTranslator
from weaverbird.backends.pypika_translator.translators.base import DataTypeMapping, Order, SQLTranslator
from weaverbird.backends.pypika_translator.translators.exceptions import (
ForbiddenSQLStep,
UnknownTableColumns,
Expand Down Expand Up @@ -198,8 +198,12 @@ def test_aggregate(base_translator: BaseTranslator, agg_type: str, default_step_

agg_func = base_translator._get_aggregate_function(agg_type)
field = Field(agg_field)
agged = agg_func(field) if agg_func is not functions.Count else functions.Count("*")
expected_query = (
Query.from_(previous_step).groupby(field).orderby(agg_field).select(field, agg_func(field).as_(new_column))
Query.from_(previous_step)
.groupby(field)
.orderby(agg_field, order=Order.ASC_NULLS_LAST)
.select(field, agged.as_(new_column))
)

assert ctx.selectable.get_sql() == expected_query.get_sql()
Expand All @@ -223,15 +227,16 @@ def test_aggregate_with_original_granularity(

agg_func = base_translator._get_aggregate_function(agg_type)
field = Field(agg_field)
agg_query = Query.from_(previous_step).groupby(field).select(field, agg_func(field).as_(new_column))
agged = agg_func(field) if agg_func is not functions.Count else functions.Count("*")
agg_query = Query.from_(previous_step).groupby(field).select(field, agged.as_(new_column))

expected_query = (
Query.from_(previous_step)
.select(*original_select)
.left_join(agg_query)
.on_field(agg_field)
.select(*original_select)
.orderby(agg_field)
.orderby(agg_field, order=Order.ASC_NULLS_LAST)
)

assert ctx.selectable.get_sql() == expected_query.get_sql()
Expand Down Expand Up @@ -544,7 +549,7 @@ def test_sort(base_translator: BaseTranslator, default_step_kwargs: dict[str, An
step = steps.SortStep(columns=columns)
ctx = base_translator.sort(step=step, columns=selected_columns, **default_step_kwargs)

expected_query = Query.from_(previous_step).select(*selected_columns).orderby(Field("name"), order=Order.asc)
expected_query = Query.from_(previous_step).select(*selected_columns).orderby(Field("name"), order=Order.ASC)

assert ctx.selectable.get_sql() == expected_query.get_sql()

Expand Down Expand Up @@ -623,7 +628,12 @@ def test_uniquegroups(base_translator: BaseTranslator, default_step_kwargs: dict
step = steps.UniqueGroupsStep(on=columns)
ctx = base_translator.uniquegroups(step=step, columns=selected_columns, **default_step_kwargs)

expected_query = Query.from_(previous_step).select(Field(column)).groupby(Field(column)).orderby(Field(column))
expected_query = (
Query.from_(previous_step)
.select(Field(column))
.groupby(Field(column))
.orderby(Field(column), order=Order.ASC_NULLS_LAST)
)

assert ctx.selectable.get_sql() == expected_query.get_sql()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
aggregations=[{"aggfunction": "count", "new_columns": ["beer_count"], "columns": ["name"]}],
),
],
'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT("name") "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind"',
'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT(*) "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind" ASC NULLS LAST',
),
(
[
Expand All @@ -29,7 +29,7 @@
keep_original_granularity=True,
),
],
'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "__step_0_dummy__"."price_per_l","__step_0_dummy__"."alcohol_degree","__step_0_dummy__"."name","__step_0_dummy__"."cost","__step_0_dummy__"."beer_kind","__step_0_dummy__"."volume_ml","__step_0_dummy__"."brewing_date","__step_0_dummy__"."nullable_name","sq0"."beer_count" FROM "__step_0_dummy__" LEFT JOIN (WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT("name") "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind") "sq0" ON "__step_0_dummy__"."beer_kind"="sq0"."beer_kind" ORDER BY "__step_0_dummy__"."beer_kind"',
'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "__step_0_dummy__"."price_per_l","__step_0_dummy__"."alcohol_degree","__step_0_dummy__"."name","__step_0_dummy__"."cost","__step_0_dummy__"."beer_kind","__step_0_dummy__"."volume_ml","__step_0_dummy__"."brewing_date","__step_0_dummy__"."nullable_name","sq0"."beer_count" FROM "__step_0_dummy__" LEFT JOIN (WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") SELECT "beer_kind",COUNT(*) "beer_count" FROM "__step_0_dummy__" GROUP BY "beer_kind") "sq0" ON "__step_0_dummy__"."beer_kind"="sq0"."beer_kind" ORDER BY "__step_0_dummy__"."beer_kind" ASC NULLS LAST',
),
(
[
Expand All @@ -47,7 +47,7 @@
),
steps.AbsoluteValueStep(column="avg_price_per_l", new_column="avg_price_per_l_abs"),
],
'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") ,__step_1_dummy__ AS (SELECT "beer_kind",COUNT("name") "beer_count",AVG("price_per_l") "avg_price_per_l" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind") SELECT "beer_kind","beer_count","avg_price_per_l",ABS("avg_price_per_l") "avg_price_per_l_abs" FROM "__step_1_dummy__"',
'WITH __step_0_dummy__ AS (SELECT "price_per_l","alcohol_degree","name","cost","beer_kind","volume_ml","brewing_date","nullable_name" FROM "beers_tiny") ,__step_1_dummy__ AS (SELECT "beer_kind",COUNT(*) "beer_count",AVG("price_per_l") "avg_price_per_l" FROM "__step_0_dummy__" GROUP BY "beer_kind" ORDER BY "beer_kind" ASC NULLS LAST) SELECT "beer_kind","beer_count","avg_price_per_l",ABS("avg_price_per_l") "avg_price_per_l_abs" FROM "__step_1_dummy__"',
),
(
[
Expand Down

0 comments on commit 39d5d00

Please sign in to comment.