diff --git a/server/CHANGELOG.md b/server/CHANGELOG.md index f51386f700..8284dda859 100644 --- a/server/CHANGELOG.md +++ b/server/CHANGELOG.md @@ -4,6 +4,7 @@ ### Fixed +- Pandas: the `pivot` step now keeps null values in index columns and handles them like any other value. - Pandas & Pypika: the `count` aggregation of the aggregate step now properly counts nulls ## [0.48.6] - 2024-12-11 diff --git a/server/src/weaverbird/backends/pandas_executor/steps/pivot.py b/server/src/weaverbird/backends/pandas_executor/steps/pivot.py index 657ac65127..ee980a9182 100644 --- a/server/src/weaverbird/backends/pandas_executor/steps/pivot.py +++ b/server/src/weaverbird/backends/pandas_executor/steps/pivot.py @@ -1,8 +1,11 @@ +from numpy import nan as nan_value from pandas import DataFrame from weaverbird.backends.pandas_executor.types import DomainRetriever, PipelineExecutor from weaverbird.pipeline.steps import PivotStep +PIVOT_NULL_VALUE = "__WEAVERBIRD_PIVOT_NULL_VALUE__" + def execute_pivot( step: PivotStep, @@ -10,11 +13,21 @@ def execute_pivot( domain_retriever: DomainRetriever = None, execute_pipeline: PipelineExecutor = None, ) -> DataFrame: + # Create alias for null values in indexed columns before running pivot to avoid removing null values + for idx_column in step.index: + df[idx_column] = df[idx_column].fillna(PIVOT_NULL_VALUE) + df[PIVOT_NULL_VALUE] = nan_value + pivoted_df = df.pivot_table( values=step.value_column, index=step.index, columns=step.column_to_pivot, aggfunc="mean" if step.agg_function == "avg" else step.agg_function, ).reset_index() + + # Replace alias with null value + for idx_column in step.index: + pivoted_df[idx_column] = pivoted_df[idx_column].replace(PIVOT_NULL_VALUE, nan_value) + pivoted_df.columns.name = None return pivoted_df diff --git a/server/tests/backends/fixtures/pivot/with_nulls.yaml b/server/tests/backends/fixtures/pivot/with_nulls.yaml new file mode 100644 index 0000000000..3bfaa54bfa --- /dev/null +++ 
b/server/tests/backends/fixtures/pivot/with_nulls.yaml @@ -0,0 +1,68 @@ +exclude: +- athena_pypika +- bigquery_pypika +- mysql_pypika +- postgres_pypika +- redshift_pypika +- snowflake_pypika +step: + pipeline: + - name: pivot + index: + - LABEL + - TYPE + column_to_pivot: ALPHA + value_column: COST + agg_function: sum +input: + schema: + fields: + - name: LABEL + type: integer + - name: TYPE + type: string + - name: COST + type: integer + - name: ALPHA + type: string + pandas_version: 0.20.0 + data: + - LABEL: 1 + TYPE: PARENT + COST: 5 + ALPHA: ALPHA + - LABEL: 2 + TYPE: PARENT + COST: 100 + ALPHA: ALPHA + - LABEL: 1 + TYPE: + COST: 2 + ALPHA: ALPHA + - LABEL: 2 + TYPE: + COST: 28 + ALPHA: ALPHA +expected: + schema: + fields: + - name: LABEL + type: integer + - name: TYPE + type: string + - name: ALPHA + type: integer + pandas_version: 0.20.0 + data: + - LABEL: 1 + TYPE: PARENT + ALPHA: 5 + - LABEL: 1 + TYPE: + ALPHA: 2 + - LABEL: 2 + TYPE: PARENT + ALPHA: 100 + - LABEL: 2 + TYPE: + ALPHA: 28 diff --git a/server/tests/steps/test_pivot.py b/server/tests/steps/test_pivot.py index c9d32d4fa2..62e5b44a76 100644 --- a/server/tests/steps/test_pivot.py +++ b/server/tests/steps/test_pivot.py @@ -72,7 +72,7 @@ def test_benchmark_pivot(benchmark): step = PivotStep( name="pivot", - index=["group"], + index=["id"], column_to_pivot="group", value_column="value", agg_function="avg",