Skip to content

Commit

Permalink
FIX-#6752: Preserve dtypes cache on '.insert()' (#6757)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <[email protected]>
  • Loading branch information
dchigarev authored Nov 20, 2023
1 parent 0ba2a46 commit 257de20
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 3 deletions.
9 changes: 8 additions & 1 deletion modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2804,6 +2804,7 @@ def apply_full_axis_select_indices(
new_index=None,
new_columns=None,
keep_remaining=False,
new_dtypes=None,
):
"""
Apply a function across an entire axis for a subset of the data.
Expand All @@ -2826,6 +2827,10 @@ def apply_full_axis_select_indices(
advance, and if not provided it must be computed.
keep_remaining : boolean, default: False
Whether or not to drop the data that is not computed over.
new_dtypes : ModinDtypes or pandas.Series, optional
The data types of the result. This is an optimization
because there are functions that always result in a particular data
type, and allows us to avoid (re)computing it.
Returns
-------
Expand Down Expand Up @@ -2854,7 +2859,9 @@ def apply_full_axis_select_indices(
new_index = self.index if axis == 1 else None
if new_columns is None:
new_columns = self.columns if axis == 0 else None
return self.__constructor__(new_partitions, new_index, new_columns, None, None)
return self.__constructor__(
new_partitions, new_index, new_columns, None, None, dtypes=new_dtypes
)

@lazy_metadata_decorator(apply_axis="both")
def apply_select_indices(
Expand Down
22 changes: 20 additions & 2 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
SeriesGroupByDefault,
)
from modin.core.dataframe.base.dataframe.utils import join_columns
from modin.core.dataframe.pandas.metadata import ModinDtypes
from modin.core.dataframe.pandas.metadata import DtypesDescriptor, ModinDtypes
from modin.core.storage_formats.base.query_compiler import BaseQueryCompiler
from modin.error_message import ErrorMessage
from modin.utils import (
Expand Down Expand Up @@ -3112,6 +3112,23 @@ def insert(df, internal_indices=[]): # pragma: no cover
df.insert(internal_idx, column, value)
return df

if hasattr(value, "dtype"):
value_dtype = value.dtype
elif is_scalar(value):
value_dtype = np.dtype(type(value))
else:
value_dtype = np.array(value).dtype

new_columns = self.columns.insert(loc, column)
new_dtypes = ModinDtypes.concat(
[
self._modin_frame._dtypes,
DtypesDescriptor({column: value_dtype}, cols_with_unknown_dtypes=[]),
]
).lazy_get(
new_columns
) # get dtypes in a proper order

# TODO: rework by passing list-like values to `apply_select_indices`
# as an item to distribute
new_modin_frame = self._modin_frame.apply_full_axis_select_indices(
Expand All @@ -3120,7 +3137,8 @@ def insert(df, internal_indices=[]): # pragma: no cover
numeric_indices=[loc],
keep_remaining=True,
new_index=self.index,
new_columns=self.columns.insert(loc, column),
new_columns=new_columns,
new_dtypes=new_dtypes,
)
return self.__constructor__(new_modin_frame)

Expand Down
55 changes: 55 additions & 0 deletions modin/test/storage_formats/pandas/test_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -1965,6 +1965,61 @@ class TestZeroComputationDtypes:
Test cases that shouldn't trigger dtypes computation during their execution.
"""

@pytest.mark.parametrize("self_dtype", ["materialized", "partial", "unknown"])
@pytest.mark.parametrize(
"value, value_dtype",
[
[3.5, np.dtype(float)],
[[3.5, 2.4], np.dtype(float)],
[np.array([3.5, 2.4]), np.dtype(float)],
[pd.Series([3.5, 2.4]), np.dtype(float)],
],
)
def test_preserve_dtypes_insert(self, self_dtype, value, value_dtype):
with mock.patch.object(PandasDataframe, "_compute_dtypes") as patch:
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
if self_dtype == "materialized":
assert df._query_compiler._modin_frame.has_materialized_dtypes
elif self_dtype == "partial":
df._query_compiler._modin_frame.set_dtypes_cache(
ModinDtypes(
DtypesDescriptor(
{"a": np.dtype(int)}, cols_with_unknown_dtypes=["b"]
)
)
)
elif self_dtype == "unknown":
df._query_compiler._modin_frame.set_dtypes_cache(None)
else:
raise NotImplementedError(self_dtype)

df.insert(loc=0, column="c", value=value)

if self_dtype == "materialized":
result_dtype = pandas.Series(
[value_dtype, np.dtype(int), np.dtype(int)], index=["c", "a", "b"]
)
assert df._query_compiler._modin_frame.has_materialized_dtypes
assert df.dtypes.equals(result_dtype)
elif self_dtype == "partial":
result_dtype = DtypesDescriptor(
{"a": np.dtype(int), "c": value_dtype},
cols_with_unknown_dtypes=["b"],
columns_order={0: "c", 1: "a", 2: "b"},
)
df._query_compiler._modin_frame._dtypes._value.equals(result_dtype)
elif self_dtype == "unknown":
result_dtype = DtypesDescriptor(
{"c": value_dtype},
cols_with_unknown_dtypes=["a", "b"],
columns_order={0: "c", 1: "a", 2: "b"},
)
df._query_compiler._modin_frame._dtypes._value.equals(result_dtype)
else:
raise NotImplementedError(self_dtype)

patch.assert_not_called()

def test_get_dummies_case(self):
with mock.patch.object(PandasDataframe, "_compute_dtypes") as patch:
df = pd.DataFrame(
Expand Down

0 comments on commit 257de20

Please sign in to comment.