Skip to content

Commit

Permalink
Merge branch 'main' into enh/pivot_combine
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Oct 20, 2023
2 parents 57240bd + f67ffd4 commit 7545bd3
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,17 @@ def _get_successor_required_columns(self, data: TileableData) -> Set[Any]:
"""
successors = self._get_successors(data)
if successors:
return set().union(
res = set().union(
*[self._context[successor][data] for successor in successors]
)
# When getting the required columns of a DataFrameIndex node, we need to consider itself.
if (
isinstance(data, BaseDataFrameData)
and isinstance(data.op, DataFrameIndex)
and len(data.dtypes) > 0
):
res = res.union(set(data.dtypes.index))
return res
else:
return self._get_all_columns(data)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import tempfile

import numpy as np
import pandas as pd
import pytest

Expand Down Expand Up @@ -598,3 +599,42 @@ def test_setitem(setup, gen_data1):
raw1["c5"] = raw2["c1"]
expected = raw1.groupby(by="c1", as_index=False).sum()["c2"]
pd.testing.assert_series_equal(r.execute().fetch(), expected)


def test_merge_index_groupby_agg(setup, gen_data1):
file_path, file_path2 = gen_data1
left = md.read_csv(file_path)
right = md.read_csv(file_path2)
r = left.merge(right, on="c1")
data = r[["c1", "c2_x", "c2_y", "c4_x", "c4_y"]]

def udf(x):
return np.sum(x)

res = data.groupby("c1").agg({"c2_x": udf})

graph = res.build_graph()
optimize(graph)

agg_node = graph.result_tileables[0]
assert isinstance(agg_node.op, DataFrameGroupByAgg)

assert len(graph.predecessors(agg_node)) == 1
index_node = graph.predecessors(agg_node)[0]
assert type(index_node.op) is DataFrameIndex
assert set(index_node.op.col_names) == {"c1", "c2_x"}

index_node2 = graph.predecessors(index_node)[0]
assert type(index_node2.op) is DataFrameIndex
assert set(index_node2.op.col_names) == {"c1", "c2_x", "c2_y", "c4_x", "c4_y"}

merge_node = graph.predecessors(index_node2)[0]
assert type(merge_node.op) is DataFrameMerge

read_csv_node_left, read_csv_node_right = graph.predecessors(merge_node)
assert type(read_csv_node_left.op) is DataFrameReadCSV
assert type(read_csv_node_right.op) is DataFrameReadCSV
assert len(read_csv_node_left.op.usecols) == 3
assert len(read_csv_node_right.op.usecols) == 3
assert set(read_csv_node_left.op.usecols) == {"c1", "c2", "c4"}
assert set(read_csv_node_right.op.usecols) == {"c1", "c2", "c4"}

0 comments on commit 7545bd3

Please sign in to comment.