From f67ffd4d6da7a1be76c2c5342421f56a01697a09 Mon Sep 17 00:00:00 2001 From: Chengjie Li <109656400+ChengjieLi28@users.noreply.github.com> Date: Thu, 19 Oct 2023 22:04:59 -0500 Subject: [PATCH] BUG: column pruning causes missing columns on `DataFrameIndex` op (#743) Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- .../column_pruning/column_pruning_rule.py | 10 ++++- .../tests/test_column_pruning.py | 40 +++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/column_pruning_rule.py b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/column_pruning_rule.py index a90544a61..9b7c77e8e 100644 --- a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/column_pruning_rule.py +++ b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/column_pruning_rule.py @@ -55,9 +55,17 @@ def _get_successor_required_columns(self, data: TileableData) -> Set[Any]: """ successors = self._get_successors(data) if successors: - return set().union( + res = set().union( *[self._context[successor][data] for successor in successors] ) + # When getting the required columns of a DataFrameIndex node, we also need to include the node's own columns. 
+ if ( + isinstance(data, BaseDataFrameData) + and isinstance(data.op, DataFrameIndex) + and len(data.dtypes) > 0 + ): + res = res.union(set(data.dtypes.index)) + return res else: return self._get_all_columns(data) diff --git a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_column_pruning.py b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_column_pruning.py index 321a911d0..9e869f48b 100644 --- a/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_column_pruning.py +++ b/python/xorbits/_mars/optimization/logical/tileable/column_pruning/tests/test_column_pruning.py @@ -15,6 +15,7 @@ import os import tempfile +import numpy as np import pandas as pd import pytest @@ -598,3 +599,42 @@ def test_setitem(setup, gen_data1): raw1["c5"] = raw2["c1"] expected = raw1.groupby(by="c1", as_index=False).sum()["c2"] pd.testing.assert_series_equal(r.execute().fetch(), expected) + + +def test_merge_index_groupby_agg(setup, gen_data1): + file_path, file_path2 = gen_data1 + left = md.read_csv(file_path) + right = md.read_csv(file_path2) + r = left.merge(right, on="c1") + data = r[["c1", "c2_x", "c2_y", "c4_x", "c4_y"]] + + def udf(x): + return np.sum(x) + + res = data.groupby("c1").agg({"c2_x": udf}) + + graph = res.build_graph() + optimize(graph) + + agg_node = graph.result_tileables[0] + assert isinstance(agg_node.op, DataFrameGroupByAgg) + + assert len(graph.predecessors(agg_node)) == 1 + index_node = graph.predecessors(agg_node)[0] + assert type(index_node.op) is DataFrameIndex + assert set(index_node.op.col_names) == {"c1", "c2_x"} + + index_node2 = graph.predecessors(index_node)[0] + assert type(index_node2.op) is DataFrameIndex + assert set(index_node2.op.col_names) == {"c1", "c2_x", "c2_y", "c4_x", "c4_y"} + + merge_node = graph.predecessors(index_node2)[0] + assert type(merge_node.op) is DataFrameMerge + + read_csv_node_left, read_csv_node_right = graph.predecessors(merge_node) + assert 
type(read_csv_node_left.op) is DataFrameReadCSV + assert type(read_csv_node_right.op) is DataFrameReadCSV + assert len(read_csv_node_left.op.usecols) == 3 + assert len(read_csv_node_right.op.usecols) == 3 + assert set(read_csv_node_left.op.usecols) == {"c1", "c2", "c4"} + assert set(read_csv_node_right.op.usecols) == {"c1", "c2", "c4"}