Skip to content

Commit

Permalink
FIX-#6781: Use pandas.api.types.pandas_dtype to convert to valid nu…
Browse files Browse the repository at this point in the history
…mpy and pandas only dtypes (#6788)

Signed-off-by: JignyasAnand <[email protected]>
  • Loading branch information
JignyasAnand authored Dec 1, 2023
1 parent b8b8434 commit 68c69f8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion modin/core/dataframe/pandas/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,8 @@ def _compute_tree_reduce_metadata(self, axis, new_parts, dtypes=None):
dtypes = self.copy_dtypes_cache()
elif dtypes is not None:
dtypes = pandas.Series(
[np.dtype(dtypes)] * len(new_axes[1]), index=new_axes[1]
[pandas.api.types.pandas_dtype(dtypes)] * len(new_axes[1]),
index=new_axes[1],
)

result = self.__constructor__(
Expand Down
10 changes: 9 additions & 1 deletion modin/pandas/test/dataframe/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pandas._testing import assert_series_equal

import modin.pandas as pd
from modin.config import NPartitions, StorageFormat
from modin.config import Engine, NPartitions, StorageFormat
from modin.pandas.test.utils import (
arg_keys,
assert_dtypes_equal,
Expand Down Expand Up @@ -306,6 +306,14 @@ def test_sum(data, axis, skipna, is_transposed):
df_equals(modin_result, pandas_result)


@pytest.mark.skipif(Engine.get() == "Native", reason="Fails on HDK")
@pytest.mark.parametrize("dtype", ["int64", "Int64"])
def test_dtype_consistency(dtype):
# test for issue #6781
res_dtype = pd.DataFrame([1, 2, 3, 4], dtype=dtype).sum().dtype
assert res_dtype == pandas.api.types.pandas_dtype(dtype)


@pytest.mark.parametrize("fn", ["prod, sum"])
@pytest.mark.parametrize(
"numeric_only", bool_arg_values, ids=arg_keys("numeric_only", bool_arg_keys)
Expand Down

0 comments on commit 68c69f8

Please sign in to comment.