-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: keep null values in pivot index columns
- Loading branch information
1 parent
031f919
commit 9d0ff42
Showing
3 changed files
with
85 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
13 changes: 13 additions & 0 deletions
13
server/src/weaverbird/backends/pandas_executor/steps/pivot.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,33 @@ | ||
from numpy import nan as nan_value | ||
from pandas import DataFrame | ||
|
||
from weaverbird.backends.pandas_executor.types import DomainRetriever, PipelineExecutor | ||
from weaverbird.pipeline.steps import PivotStep | ||
|
||
# Placeholder substituted for nulls in index columns while the pivot runs, so
# that rows whose index value is null are not removed by the pivot.
PIVOT_NULL_VALUE = "__WEAVERBIRD_PIVOT_NULL_VALUE__"


def execute_pivot(
    step: PivotStep,
    df: DataFrame,
    domain_retriever: DomainRetriever = None,
    execute_pipeline: PipelineExecutor = None,
) -> DataFrame:
    """Pivot ``df`` according to ``step`` while keeping null index values.

    Nulls found in the ``step.index`` columns are temporarily replaced by
    ``PIVOT_NULL_VALUE`` before calling ``pivot_table`` and turned back into
    nulls afterwards, so the corresponding rows survive the pivot.

    Args:
        step: Pivot configuration; ``index``, ``column_to_pivot``,
            ``value_column`` and ``agg_function`` are read.
        df: Input data. It is left untouched.
        domain_retriever: Unused by this step.
        execute_pipeline: Unused by this step.

    Returns:
        A new DataFrame with the index columns as regular columns and one
        column per distinct value of ``step.column_to_pivot``.
    """
    # Work on a copy: the placeholder substitution must not leak into the
    # caller's DataFrame.
    working_df = df.copy()
    for idx_column in step.index:
        working_df[idx_column] = working_df[idx_column].fillna(PIVOT_NULL_VALUE)

    pivoted_df = working_df.pivot_table(
        values=step.value_column,
        index=step.index,
        columns=step.column_to_pivot,
        # "avg" is not an aggregation name pandas accepts; map it to "mean".
        aggfunc="mean" if step.agg_function == "avg" else step.agg_function,
    ).reset_index()

    # Turn the placeholder back into real nulls in the output.
    for idx_column in step.index:
        pivoted_df[idx_column] = pivoted_df[idx_column].replace(PIVOT_NULL_VALUE, nan_value)

    # Drop the name pivot_table gives to the columns axis.
    pivoted_df.columns.name = None
    return pivoted_df
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
exclude: | ||
- athena_pypika | ||
- bigquery_pypika | ||
- mysql_pypika | ||
- postgres_pypika | ||
- redshift_pypika | ||
- snowflake_pypika | ||
step: | ||
pipeline: | ||
- name: pivot | ||
index: | ||
- LABEL | ||
- TYPE | ||
column_to_pivot: ALPHA | ||
value_column: COST | ||
agg_function: sum | ||
input: | ||
schema: | ||
fields: | ||
- name: LABEL | ||
type: integer | ||
- name: TYPE | ||
type: string | ||
- name: COST | ||
type: integer | ||
- name: ALPHA | ||
type: string | ||
pandas_version: 0.20.0 | ||
data: | ||
- LABEL: 1 | ||
TYPE: PARENT | ||
COST: 5 | ||
ALPHA: ALPHA | ||
- LABEL: 2 | ||
TYPE: PARENT | ||
COST: 100 | ||
ALPHA: ALPHA | ||
- LABEL: 1 | ||
TYPE: | ||
COST: 2 | ||
ALPHA: ALPHA | ||
- LABEL: 2 | ||
TYPE: | ||
COST: 28 | ||
ALPHA: ALPHA | ||
expected: | ||
schema: | ||
fields: | ||
- name: LABEL | ||
type: integer | ||
- name: TYPE | ||
type: string | ||
- name: ALPHA | ||
type: integer | ||
pandas_version: 0.20.0 | ||
data: | ||
- LABEL: 1 | ||
TYPE: PARENT | ||
ALPHA: 5 | ||
- LABEL: 1 | ||
TYPE: | ||
ALPHA: 2 | ||
- LABEL: 2 | ||
TYPE: PARENT | ||
ALPHA: 100 | ||
- LABEL: 2 | ||
TYPE: | ||
ALPHA: 28 |