From 68c69f81bb73a0c499b21fb9b3f90d407872a6dc Mon Sep 17 00:00:00 2001 From: Jignyas Anand Siripurapu <93654470+JignyasAnand@users.noreply.github.com> Date: Fri, 1 Dec 2023 20:24:26 +0530 Subject: [PATCH] FIX-#6781: Use `pandas.api.types.pandas_dtype` to convert to valid numpy and pandas only dtypes (#6788) Signed-off-by: JignyasAnand --- modin/core/dataframe/pandas/dataframe/dataframe.py | 3 ++- modin/pandas/test/dataframe/test_reduce.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/modin/core/dataframe/pandas/dataframe/dataframe.py b/modin/core/dataframe/pandas/dataframe/dataframe.py index 7676e5eae53..8140ab0c43d 100644 --- a/modin/core/dataframe/pandas/dataframe/dataframe.py +++ b/modin/core/dataframe/pandas/dataframe/dataframe.py @@ -1960,7 +1960,8 @@ def _compute_tree_reduce_metadata(self, axis, new_parts, dtypes=None): dtypes = self.copy_dtypes_cache() elif dtypes is not None: dtypes = pandas.Series( - [np.dtype(dtypes)] * len(new_axes[1]), index=new_axes[1] + [pandas.api.types.pandas_dtype(dtypes)] * len(new_axes[1]), + index=new_axes[1], ) result = self.__constructor__( diff --git a/modin/pandas/test/dataframe/test_reduce.py b/modin/pandas/test/dataframe/test_reduce.py index ceda40e4161..6edee442de6 100644 --- a/modin/pandas/test/dataframe/test_reduce.py +++ b/modin/pandas/test/dataframe/test_reduce.py @@ -18,7 +18,7 @@ from pandas._testing import assert_series_equal import modin.pandas as pd -from modin.config import NPartitions, StorageFormat +from modin.config import Engine, NPartitions, StorageFormat from modin.pandas.test.utils import ( arg_keys, assert_dtypes_equal, @@ -306,6 +306,14 @@ def test_sum(data, axis, skipna, is_transposed): df_equals(modin_result, pandas_result) +@pytest.mark.skipif(Engine.get() == "Native", reason="Fails on HDK") +@pytest.mark.parametrize("dtype", ["int64", "Int64"]) +def test_dtype_consistency(dtype): + # test for issue #6781 + res_dtype = pd.DataFrame([1, 2, 3, 4], dtype=dtype).sum().dtype + assert res_dtype == pandas.api.types.pandas_dtype(dtype) + + @pytest.mark.parametrize("fn", ["prod, sum"]) @pytest.mark.parametrize( "numeric_only", bool_arg_values, ids=arg_keys("numeric_only", bool_arg_keys)