Skip to content

Commit

Permalink
FIX-#6732: Fix inferring result dtypes for binary operations (#6737)
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitry Chigarev <[email protected]>
  • Loading branch information
dchigarev authored Nov 16, 2023
1 parent 2c6472c commit 1b36f4c
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 68 deletions.
108 changes: 56 additions & 52 deletions modin/core/dataframe/algebra/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,12 @@
from .operator import Operator


def coerce_int_to_float64(dtype: np.dtype) -> np.dtype:
"""
Coerce dtype to float64 if it is a variant of integer.
If dtype is integer, function returns float64 datatype.
If not, returns the datatype argument itself.
Parameters
----------
dtype : np.dtype
NumPy datatype.
Returns
-------
dtype : np.dtype
Returns float64 for all int datatypes or returns the datatype itself
for other types.
Notes
-----
Used to precompute datatype in case of division in pandas.
"""
if dtype in np.sctypes["int"] + np.sctypes["uint"]:
return np.dtype(np.float64)
else:
return dtype


def maybe_compute_dtypes_common_cast(
first, second, trigger_computations=False, axis=0
first,
second,
trigger_computations=False,
axis=0,
func=None,
) -> Optional[pandas.Series]:
"""
Precompute data types for binary operations by finding common type between operands.
Expand All @@ -70,6 +46,9 @@ def maybe_compute_dtypes_common_cast(
have materialized dtypes.
axis : int, default: 0
Axis to perform the binary operation along.
func : callable(pandas.DataFrame, pandas.DataFrame) -> pandas.DataFrame, optional
If specified, will use this function to perform the "try_sample" method
(see ``Binary.register()`` docs for more details).
Returns
-------
Expand Down Expand Up @@ -138,18 +117,33 @@ def maybe_compute_dtypes_common_cast(

# If at least one column doesn't match, the result of the non matching column would be nan.
nan_dtype = np.dtype(type(np.nan))
dtypes = pandas.Series(
[
pandas.core.dtypes.cast.find_common_type(
[
dtypes_first[x],
dtypes_second[x],
]
dtypes = None
if func is not None:
try:
df1 = pandas.DataFrame([[1] * len(common_columns)]).astype(
{i: dtypes_first[col] for i, col in enumerate(common_columns)}
)
for x in common_columns
],
index=common_columns,
)
df2 = pandas.DataFrame([[1] * len(common_columns)]).astype(
{i: dtypes_second[col] for i, col in enumerate(common_columns)}
)
dtypes = func(df1, df2).dtypes.set_axis(common_columns)
# it sometimes doesn't work correctly with strings, so falling back to
# the "common_cast" method in this case
except TypeError:
pass
if dtypes is None:
dtypes = pandas.Series(
[
pandas.core.dtypes.cast.find_common_type(
[
dtypes_first[x],
dtypes_second[x],
]
)
for x in common_columns
],
index=common_columns,
)
dtypes = pandas.concat(
[
dtypes,
Expand Down Expand Up @@ -211,7 +205,9 @@ def maybe_build_dtypes_series(
return dtypes


def try_compute_new_dtypes(first, second, infer_dtypes=None, result_dtype=None, axis=0):
def try_compute_new_dtypes(
first, second, infer_dtypes=None, result_dtype=None, axis=0, func=None
):
"""
Precompute resulting dtypes of the binary operation if possible.
Expand All @@ -225,12 +221,14 @@ def try_compute_new_dtypes(first, second, infer_dtypes=None, result_dtype=None,
First operand of the binary operation.
second : PandasQueryCompiler, list-like or scalar
Second operand of the binary operation.
infer_dtypes : {"common_cast", "float", "bool", None}, default: None
infer_dtypes : {"common_cast", "try_sample", "bool", None}, default: None
How dtypes should be infered (see ``Binary.register`` doc for more info).
result_dtype : np.dtype, optional
NumPy dtype of the result. If not specified it will be inferred from the `infer_dtypes` parameter.
axis : int, default: 0
Axis to perform the binary operation along.
func : callable(pandas.DataFrame, pandas.DataFrame) -> pandas.DataFrame, optional
A callable to be used for the "try_sample" method.
Returns
-------
Expand All @@ -243,11 +241,17 @@ def try_compute_new_dtypes(first, second, infer_dtypes=None, result_dtype=None,
if infer_dtypes == "bool" or is_bool_dtype(result_dtype):
dtypes = maybe_build_dtypes_series(first, second, dtype=np.dtype(bool))
elif infer_dtypes == "common_cast":
dtypes = maybe_compute_dtypes_common_cast(first, second, axis=axis)
elif infer_dtypes == "float":
dtypes = maybe_compute_dtypes_common_cast(first, second, axis=axis)
if dtypes is not None:
dtypes = dtypes.apply(coerce_int_to_float64)
dtypes = maybe_compute_dtypes_common_cast(
first, second, axis=axis, func=None
)
elif infer_dtypes == "try_sample":
if func is None:
raise ValueError(
"'func' must be specified if dtypes infering method is 'try_sample'"
)
dtypes = maybe_compute_dtypes_common_cast(
first, second, axis=axis, func=func
)
else:
# For now we only know how to handle `result_dtype == bool` as that's
# the only value that is being passed here right now, it's unclear
Expand Down Expand Up @@ -283,12 +287,12 @@ def register(
labels : {"keep", "replace", "drop"}, default: "replace"
Whether keep labels from left Modin DataFrame, replace them with labels
from joined DataFrame or drop altogether to make them be computed lazily later.
infer_dtypes : {"common_cast", "float", "bool", None}, default: None
infer_dtypes : {"common_cast", "try_sample", "bool", None}, default: None
How dtypes should be inferred.
* If "common_cast", casts to common dtype of operand columns.
* If "float", performs type casting by finding common dtype.
If the common dtype is any of the integer types, perform type casting to float.
Used in case of truediv.
* If "try_sample", creates small pandas DataFrames with dtypes of operands and
runs the `func` on them to determine output dtypes. If a ``TypeError`` is raised
during this process, fallback to "common_cast" method.
* If "bool", dtypes would be a boolean series with same size as that of operands.
* If ``None``, do not infer new dtypes (they will be computed manually once accessed).
Expand Down Expand Up @@ -339,7 +343,7 @@ def caller(
other = other.transpose()
if dtypes != "copy":
dtypes = try_compute_new_dtypes(
query_compiler, other, infer_dtypes, dtypes, axis
query_compiler, other, infer_dtypes, dtypes, axis, func
)

shape_hint = None
Expand Down
33 changes: 18 additions & 15 deletions modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,34 +388,37 @@ def to_numpy(self, **kwargs):
# such that columns/rows that don't have an index on the other DataFrame
# result in NaN values.

add = Binary.register(pandas.DataFrame.add, infer_dtypes="common_cast")
add = Binary.register(pandas.DataFrame.add, infer_dtypes="try_sample")
# 'combine' and 'combine_first' are working with UDFs, so it's better not so sample them
combine = Binary.register(pandas.DataFrame.combine, infer_dtypes="common_cast")
combine_first = Binary.register(pandas.DataFrame.combine_first, infer_dtypes="bool")
combine_first = Binary.register(
pandas.DataFrame.combine_first, infer_dtypes="common_cast"
)
eq = Binary.register(pandas.DataFrame.eq, infer_dtypes="bool")
equals = Binary.register(
lambda df, other: pandas.DataFrame([[df.equals(other)]]),
join_type=None,
labels="drop",
infer_dtypes="bool",
)
floordiv = Binary.register(pandas.DataFrame.floordiv, infer_dtypes="common_cast")
floordiv = Binary.register(pandas.DataFrame.floordiv, infer_dtypes="try_sample")
ge = Binary.register(pandas.DataFrame.ge, infer_dtypes="bool")
gt = Binary.register(pandas.DataFrame.gt, infer_dtypes="bool")
le = Binary.register(pandas.DataFrame.le, infer_dtypes="bool")
lt = Binary.register(pandas.DataFrame.lt, infer_dtypes="bool")
mod = Binary.register(pandas.DataFrame.mod, infer_dtypes="common_cast")
mul = Binary.register(pandas.DataFrame.mul, infer_dtypes="common_cast")
rmul = Binary.register(pandas.DataFrame.rmul, infer_dtypes="common_cast")
mod = Binary.register(pandas.DataFrame.mod, infer_dtypes="try_sample")
mul = Binary.register(pandas.DataFrame.mul, infer_dtypes="try_sample")
rmul = Binary.register(pandas.DataFrame.rmul, infer_dtypes="try_sample")
ne = Binary.register(pandas.DataFrame.ne, infer_dtypes="bool")
pow = Binary.register(pandas.DataFrame.pow, infer_dtypes="common_cast")
radd = Binary.register(pandas.DataFrame.radd, infer_dtypes="common_cast")
rfloordiv = Binary.register(pandas.DataFrame.rfloordiv, infer_dtypes="common_cast")
rmod = Binary.register(pandas.DataFrame.rmod, infer_dtypes="common_cast")
rpow = Binary.register(pandas.DataFrame.rpow, infer_dtypes="common_cast")
rsub = Binary.register(pandas.DataFrame.rsub, infer_dtypes="common_cast")
rtruediv = Binary.register(pandas.DataFrame.rtruediv, infer_dtypes="float")
sub = Binary.register(pandas.DataFrame.sub, infer_dtypes="common_cast")
truediv = Binary.register(pandas.DataFrame.truediv, infer_dtypes="float")
pow = Binary.register(pandas.DataFrame.pow, infer_dtypes="try_sample")
radd = Binary.register(pandas.DataFrame.radd, infer_dtypes="try_sample")
rfloordiv = Binary.register(pandas.DataFrame.rfloordiv, infer_dtypes="try_sample")
rmod = Binary.register(pandas.DataFrame.rmod, infer_dtypes="try_sample")
rpow = Binary.register(pandas.DataFrame.rpow, infer_dtypes="try_sample")
rsub = Binary.register(pandas.DataFrame.rsub, infer_dtypes="try_sample")
rtruediv = Binary.register(pandas.DataFrame.rtruediv, infer_dtypes="try_sample")
sub = Binary.register(pandas.DataFrame.sub, infer_dtypes="try_sample")
truediv = Binary.register(pandas.DataFrame.truediv, infer_dtypes="try_sample")
__and__ = Binary.register(pandas.DataFrame.__and__, infer_dtypes="bool")
__or__ = Binary.register(pandas.DataFrame.__or__, infer_dtypes="bool")
__rand__ = Binary.register(pandas.DataFrame.__rand__, infer_dtypes="bool")
Expand Down
60 changes: 59 additions & 1 deletion modin/pandas/test/dataframe/test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

import modin.pandas as pd
from modin.config import NPartitions, StorageFormat
from modin.config import Engine, NPartitions, StorageFormat
from modin.core.dataframe.pandas.partitioning.axis_partition import (
PandasDataframeAxisPartition,
)
Expand Down Expand Up @@ -433,3 +433,61 @@ def test_non_commutative_multiply():
integer = NonCommutativeMultiplyInteger(2)
eval_general(modin_df, pandas_df, lambda s: integer * s)
eval_general(modin_df, pandas_df, lambda s: s * integer)


@pytest.mark.parametrize(
"op",
[
*("add", "radd", "sub", "rsub", "mod", "rmod", "pow", "rpow"),
*("truediv", "rtruediv", "mul", "rmul", "floordiv", "rfloordiv"),
],
)
@pytest.mark.parametrize(
"val1",
[
pytest.param([10, 20], id="int"),
pytest.param([10, True], id="obj"),
pytest.param(
[True, True],
id="bool",
marks=pytest.mark.skipif(
condition=Engine.get() == "Native", reason="Fails on HDK"
),
),
pytest.param([3.5, 4.5], id="float"),
],
)
@pytest.mark.parametrize(
"val2",
[
pytest.param([10, 20], id="int"),
pytest.param([10, True], id="obj"),
pytest.param(
[True, True],
id="bool",
marks=pytest.mark.skipif(
condition=Engine.get() == "Native", reason="Fails on HDK"
),
),
pytest.param([3.5, 4.5], id="float"),
pytest.param(2, id="int scalar"),
pytest.param(
True,
id="bool scalar",
marks=pytest.mark.skipif(
condition=Engine.get() == "Native", reason="Fails on HDK"
),
),
pytest.param(3.5, id="float scalar"),
],
)
def test_arithmetic_with_tricky_dtypes(val1, val2, op):
modin_df1, pandas_df1 = create_test_dfs(val1)
modin_df2, pandas_df2 = (
create_test_dfs(val2) if isinstance(val2, list) else (val2, val2)
)
eval_general(
(modin_df1, modin_df2),
(pandas_df1, pandas_df2),
lambda dfs: getattr(dfs[0], op)(dfs[1]),
)

0 comments on commit 1b36f4c

Please sign in to comment.