Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with missed columns in the column pruner #1679

Merged
merged 2 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_aliases(node_to_retain_columns, column_aliases_to_retain)
sql_table_node = node_to_retain_columns.as_sql_table_node
if sql_table_node is not None and sql_table_node.sql_table.schema_name is None:
self._map_required_column_aliases_in_potential_cte(
cte_alias_mapping=cte_alias_mapping,
table_name=sql_table_node.sql_table.table_name,
column_aliases=column_aliases_to_retain,
)

# Visit recursively.
self._visit_parents(node)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_name: test_column_reference_expression
test_filename: test_cte_column_pruner.py
docstring:
Test a column reference expression that does not specify a table alias.
expectation_description:
`cte_source_0__col_01` should be retained in the CTE.
---
optimizer:
SqlColumnPrunerOptimizer

sql_before_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
, test_table_alias.col_0 AS cte_source_0__col_1
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias

sql_after_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_name: test_string_expression
test_filename: test_cte_column_pruner.py
docstring:
Test a string expression that references a column in the cte.
expectation_description:
`cte_source_0__col_01` should be retained in the CTE.
---
optimizer:
SqlColumnPrunerOptimizer

sql_before_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
, test_table_alias.col_0 AS cte_source_0__col_1
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias

sql_after_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias
109 changes: 109 additions & 0 deletions tests_metricflow/sql/optimizer/test_cte_column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SqlColumnReferenceExpression,
SqlComparison,
SqlComparisonExpression,
SqlStringExpression,
)
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
Expand Down Expand Up @@ -464,3 +465,111 @@ def test_common_cte_aliases_in_nested_query(
"""
),
)


def test_string_expression(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_pruner: SqlColumnPrunerOptimizer,
    sql_plan_renderer: DefaultSqlPlanRenderer,
) -> None:
    """Test a string expression that references a column in the cte."""
    # The pieces below are created in the same order in which a single nested expression would evaluate them.

    # Top-level SELECT column: a raw string expression that declares the CTE column it uses.
    top_level_columns = (
        SqlSelectColumn(
            expr=SqlStringExpression.create(sql_expr="cte_source_0__col_0", used_columns=("cte_source_0__col_0",)),
            column_alias="top_level__col_0",
        ),
    )
    # No schema name is given, and the table name is the same as the CTE alias defined below.
    top_level_source = SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0"))

    # CTE body: two aliases over the same physical column. Only the first alias is selected by the top level.
    cte_columns = (
        SqlSelectColumn(
            expr=SqlColumnReferenceExpression.create(
                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
            ),
            column_alias="cte_source_0__col_0",
        ),
        SqlSelectColumn(
            expr=SqlColumnReferenceExpression.create(
                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
            ),
            column_alias="cte_source_0__col_1",
        ),
    )
    cte_table = SqlTableNode.create(sql_table=SqlTable(schema_name="test_schema", table_name="test_table"))
    cte_select = SqlSelectStatementNode.create(
        description="CTE source 0",
        select_columns=cte_columns,
        from_source=cte_table,
        from_source_alias="test_table_alias",
    )
    cte_node = SqlCteNode.create(cte_alias="cte_source_0", select_statement=cte_select)

    select_statement = SqlSelectStatementNode.create(
        description="Top-level SELECT",
        select_columns=top_level_columns,
        from_source=top_level_source,
        from_source_alias="cte_source_0_alias",
        cte_sources=(cte_node,),
    )

    # Compare the optimizer output for this statement against the stored snapshot.
    assert_optimizer_result_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        optimizer=column_pruner,
        sql_plan_renderer=sql_plan_renderer,
        select_statement=select_statement,
        expectation_description="`cte_source_0__col_01` should be retained in the CTE.",
    )


def test_column_reference_expression(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_pruner: SqlColumnPrunerOptimizer,
    sql_plan_renderer: DefaultSqlPlanRenderer,
) -> None:
    """Test a column reference expression that does not specify a table alias."""
    select_statement = SqlSelectStatementNode.create(
        description="Top-level SELECT",
        select_columns=(
            SqlSelectColumn(
                # Use a column reference expression (not a string expression) so that this test covers the case
                # named in the docstring instead of duplicating `test_string_expression`. The table alias is not
                # rendered, so the column appears as a bare `cte_source_0__col_0` in the SQL.
                # NOTE(review): relies on `should_render_table_alias` being accepted by
                # `SqlColumnReferenceExpression.create` - confirm, and re-check the stored snapshot.
                expr=SqlColumnReferenceExpression.create(
                    col_ref=SqlColumnReference(table_alias="cte_source_0_alias", column_name="cte_source_0__col_0"),
                    should_render_table_alias=False,
                ),
                column_alias="top_level__col_0",
            ),
        ),
        # No schema name is given, and the table name is the same as the CTE alias defined below.
        from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")),
        from_source_alias="cte_source_0_alias",
        cte_sources=(
            SqlCteNode.create(
                cte_alias="cte_source_0",
                select_statement=SqlSelectStatementNode.create(
                    description="CTE source 0",
                    # Two aliases over the same physical column. Only the first is selected by the top level.
                    select_columns=(
                        SqlSelectColumn(
                            expr=SqlColumnReferenceExpression.create(
                                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
                            ),
                            column_alias="cte_source_0__col_0",
                        ),
                        SqlSelectColumn(
                            expr=SqlColumnReferenceExpression.create(
                                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
                            ),
                            column_alias="cte_source_0__col_1",
                        ),
                    ),
                    from_source=SqlTableNode.create(
                        sql_table=SqlTable(schema_name="test_schema", table_name="test_table")
                    ),
                    from_source_alias="test_table_alias",
                ),
            ),
        ),
    )
    # Compare the optimizer output for this statement against the stored snapshot.
    assert_optimizer_result_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        optimizer=column_pruner,
        sql_plan_renderer=sql_plan_renderer,
        select_statement=select_statement,
        expectation_description="`cte_source_0__col_01` should be retained in the CTE.",
    )