diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index e2fc4f033a..b98f4c1f5d 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -class ReadSqlSourceNodeCounter(DataflowDagWalker[int]): +class _ReadSqlSourceNodeCounter(DataflowDagWalker[int]): """Counts the number of ReadSqlSourceNodes in the dataflow plan.""" @override @@ -43,8 +43,16 @@ def default_visit_action(self, current_node: DataflowPlanNode, inputs: Sequence[ def visit_source_node(self, node: ReadSqlSourceNode) -> int: # noqa: D102 return 1 - def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102 - return dataflow_plan.checked_sink_node.accept(self) + +class DataflowPlanLookup: + """A lookup class to get assorted properties about the dataflow plan.""" + + def __init__(self, dataflow_plan: DataflowPlan) -> None: # noqa: D107 + self._dataflow_plan_sink_node = dataflow_plan.checked_sink_node + + def source_node_count(self) -> int: + """Counts the number of `ReadSqlSourceNodes` in the dataflow plan.""" + return self._dataflow_plan_sink_node.accept(_ReadSqlSourceNodeCounter()) def check_optimization( # noqa: D103 @@ -70,8 +78,8 @@ def check_optimization( # noqa: D103 dag_graph=dataflow_plan, ) - source_counter = ReadSqlSourceNodeCounter() - assert source_counter.count_source_nodes(dataflow_plan) == expected_num_sources_in_unoptimized + dataflow_plan_lookup = DataflowPlanLookup(dataflow_plan) + assert dataflow_plan_lookup.source_node_count() == expected_num_sources_in_unoptimized optimizer = SourceScanOptimizer() optimized_dataflow_plan = optimizer.optimize(dataflow_plan) @@ -88,7 +96,9 @@ def check_optimization( # noqa: D103 mf_test_configuration=mf_test_configuration, dag_graph=optimized_dataflow_plan, ) - assert source_counter.count_source_nodes(optimized_dataflow_plan) == expected_num_sources_in_optimized + + optimized_dataflow_plan_lookup = DataflowPlanLookup(optimized_dataflow_plan) + assert optimized_dataflow_plan_lookup.source_node_count() == expected_num_sources_in_optimized @pytest.mark.sql_engine_snapshot