Skip to content

Commit

Permalink
fix: keep null values in pivot index columns (TCTC-9902) (#2314)
Browse files Browse the repository at this point in the history
* fix: keep null values in pivot index columns

* test: fix pandas pivot benchmark test
  • Loading branch information
julien-pinchelimouroux authored Jan 10, 2025
1 parent 39d5d00 commit 0ac518b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
1 change: 1 addition & 0 deletions server/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Fixed

- Pandas: the `pivot` step now keeps null values in index columns and handles them like any other value.
- Pandas & Pypika: the `count` aggregation of the aggregate step now properly counts nulls

## [0.48.6] - 2024-12-11
Expand Down
13 changes: 13 additions & 0 deletions server/src/weaverbird/backends/pandas_executor/steps/pivot.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
from numpy import nan as nan_value
from pandas import DataFrame

from weaverbird.backends.pandas_executor.types import DomainRetriever, PipelineExecutor
from weaverbird.pipeline.steps import PivotStep

PIVOT_NULL_VALUE = "__WEAVERBIRD_PIVOT_NULL_VALUE__"


def execute_pivot(
    step: PivotStep,
    df: DataFrame,
    domain_retriever: DomainRetriever = None,
    execute_pipeline: PipelineExecutor = None,
) -> DataFrame:
    """Pivot ``df`` as described by ``step``, keeping rows whose index values are null.

    Null values found in the ``step.index`` columns are temporarily replaced by
    the ``PIVOT_NULL_VALUE`` alias so that the pivot does not remove them, then
    turned back into nulls in the result.

    Args:
        step: Pivot configuration. Reads ``index``, ``column_to_pivot``,
            ``value_column`` and ``agg_function`` (``"avg"`` is translated to
            the pandas ``"mean"`` aggregation).
        df: Input dataframe. It is left untouched; the aliasing is done on a copy.
        domain_retriever: Not used by this step.
        execute_pipeline: Not used by this step.

    Returns:
        The pivoted dataframe, with the index columns restored as regular
        columns and the columns axis left unnamed.
    """
    # Work on a copy: aliasing nulls directly on ``df`` would leak the
    # placeholder string into the caller's dataframe.
    prepared_df = df.copy()

    # Create alias for null values in indexed columns before running pivot to
    # avoid removing null values.
    for idx_column in step.index:
        prepared_df[idx_column] = prepared_df[idx_column].fillna(PIVOT_NULL_VALUE)

    pivoted_df = prepared_df.pivot_table(
        values=step.value_column,
        index=step.index,
        columns=step.column_to_pivot,
        aggfunc="mean" if step.agg_function == "avg" else step.agg_function,
    ).reset_index()

    # Replace alias with null value
    for idx_column in step.index:
        pivoted_df[idx_column] = pivoted_df[idx_column].replace(PIVOT_NULL_VALUE, nan_value)

    # Drop the name pivot_table gives to the columns axis (the pivoted column's name).
    pivoted_df.columns.name = None
    return pivoted_df
68 changes: 68 additions & 0 deletions server/tests/backends/fixtures/pivot/with_nulls.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
exclude:
- athena_pypika
- bigquery_pypika
- mysql_pypika
- postgres_pypika
- redshift_pypika
- snowflake_pypika
step:
pipeline:
- name: pivot
index:
- LABEL
- TYPE
column_to_pivot: ALPHA
value_column: COST
agg_function: sum
input:
schema:
fields:
- name: LABEL
type: integer
- name: TYPE
type: string
- name: COST
type: integer
- name: ALPHA
type: string
pandas_version: 0.20.0
data:
- LABEL: 1
TYPE: PARENT
COST: 5
ALPHA: ALPHA
- LABEL: 2
TYPE: PARENT
COST: 100
ALPHA: ALPHA
- LABEL: 1
TYPE:
COST: 2
ALPHA: ALPHA
- LABEL: 2
TYPE:
COST: 28
ALPHA: ALPHA
expected:
schema:
fields:
- name: LABEL
type: integer
- name: TYPE
type: string
- name: ALPHA
type: integer
pandas_version: 0.20.0
data:
- LABEL: 1
TYPE: PARENT
ALPHA: 5
- LABEL: 1
TYPE:
ALPHA: 2
- LABEL: 2
TYPE: PARENT
ALPHA: 100
- LABEL: 2
TYPE:
ALPHA: 28
2 changes: 1 addition & 1 deletion server/tests/steps/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_benchmark_pivot(benchmark):

step = PivotStep(
name="pivot",
index=["group"],
index=["id"],
column_to_pivot="group",
value_column="value",
agg_function="avg",
Expand Down

0 comments on commit 0ac518b

Please sign in to comment.