diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18d2d7a1257..2e516d6c6ea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,6 +80,7 @@
       - %%: A literal '%' character.
 - Added support for `Series.between`.
 - Added support for `include_groups=False` in `DataFrameGroupBy.apply`.
+- Added support for `expand=True` in `Series.str.split`.
 
 #### Bug Fixes
 
diff --git a/docs/source/modin/supported/series_str_supported.rst b/docs/source/modin/supported/series_str_supported.rst
index 702c699f153..7c96a0ac362 100644
--- a/docs/source/modin/supported/series_str_supported.rst
+++ b/docs/source/modin/supported/series_str_supported.rst
@@ -119,7 +119,7 @@ the method in the left column.
 | ``slice_replace``           | N                               |                                                    |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``split``                   | P                               | ``N`` if `pat` is non-string, `n` is non-numeric,  |
-|                             |                                 | `expand` is set, or `regex` is set.                |
+|                             |                                 | or `regex` is set.                                 |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``startswith``              | P                               | ``N`` if the `na` parameter is set to a non-bool   |
 |                             |                                 | value.                                             |
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index cf1867bc284..bdc2eb9b4d3 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -107,6 +107,7 @@
     dense_rank,
     first_value,
     floor,
+    get,
     greatest,
     hour,
     iff,
@@ -16813,10 +16814,6 @@ def str_split(
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support non-str 'pat' argument"
             )
-        if expand:
-            ErrorMessage.not_implemented(
-                "Snowpark pandas doesn't support 'expand' argument"
-            )
         if regex:
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support 'regex' argument"
             )
@@ -16864,6 +16861,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which, based on our experiments, leaves
+                # the input column as-is whenever this condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
@@ -16907,9 +16910,93 @@ def output_col(
             )
             return self._replace_non_str(column, new_col)
 
-        new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
-            lambda col_name: output_col(col_name, pat, n)
-        )
+        def output_cols(
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
+        ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+            Args:
+                column : SnowparkColumn
+                    Input column.
+                pat : Optional[str]
+                    String to split on.
+                n : int
+                    Limit on the number of output splits.
+                max_splits : int
+                    Maximum number of splits achievable across all values in the input column.
+                    This is needed to pad rows that produce fewer splits than requested with nulls.
+            """
+            col = output_col(column, pat, n)
+            final_splits = 0
+
+            if np.isnan(n):
+                # Follow pandas behavior
+                final_splits = 1
+            elif n <= 0:
+                final_splits = max_splits
+            else:
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which, based on our experiments, leaves
+                # the input column as-is whenever this condition is satisfied.
+                final_splits = 1
+
+            return [
+                iff(
+                    array_size(col) > pandas_lit(i),
+                    get(col, pandas_lit(i)),
+                    pandas_lit(None),
+                )
+                for i in range(final_splits)
+            ]
+
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
+            splits_as_list_frame = self.str_split(
+                pat=pat,
+                n=-1,
+                expand=False,
+                regex=regex,
+            )._modin_frame
+
+            split_counts_frame = splits_as_list_frame.append_column(
+                "split_counts",
+                array_size(
+                    col(
+                        splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
+                    )
+                ),
+            )
+
+            max_count_rows = split_counts_frame.ordered_dataframe.agg(
+                max_(
+                    col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
+                ).as_("max_count")
+            ).collect()
+
+            return max_count_rows[0][0]
+
+        if expand:
+            cols = output_cols(
+                col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
+                pat,
+                n,
+                get_max_splits(),
+            )
+            new_internal_frame = self._modin_frame.project_columns(
+                list(range(len(cols))),
+                cols,
+            )
+        else:
+            new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
+                lambda col_name: output_col(col_name, pat, n)
+            )
+
         return SnowflakeQueryCompiler(new_internal_frame)
 
     def str_rsplit(
diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py
index 2c869b10f31..566a2403d98 100644
--- a/tests/integ/modin/series/test_str_accessor.py
+++ b/tests/integ/modin/series/test_str_accessor.py
@@ -12,7 +12,10 @@
 from snowflake.snowpark._internal.utils import TempObjectType
 
 import snowflake.snowpark.modin.plugin  # noqa: F401
-from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
+from tests.integ.modin.utils import (
+    assert_series_equal,
+    eval_snowpark_pandas_result,
+)
 from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
 
 TEST_DATA = [
@@ -367,10 +370,12 @@ def test_str_replace_neg(pat, n, repl, error):
         snow_ser.str.replace(pat=pat, repl=repl, n=n)
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize(
+    "pat", [None, "a", "ab", "abc", "non_occurrence_pat", "|", "%"]
+)
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=1)
-def test_str_split(pat, n):
+def test_str_split_expand_false(pat, n):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
     eval_snowpark_pandas_result(
@@ -380,6 +385,19 @@ def test_str_split(pat, n):
     )
 
 
+@pytest.mark.parametrize("pat", [None, "a", "ab", "abc", "no_occurrence_pat", "|", "%"])
+@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
+@sql_count_checker(query_count=2)
+def test_str_split_expand_true(pat, n):
+    native_ser = native_pd.Series(TEST_DATA)
+    snow_ser = pd.Series(native_ser)
+    eval_snowpark_pandas_result(
+        snow_ser,
+        native_ser,
+        lambda ser: ser.str.split(pat=pat, n=n, expand=True, regex=None),
+    )
+
+
 @pytest.mark.parametrize("regex", [None, True])
 @pytest.mark.xfail(
     reason="Snowflake SQL's split function does not support regex", strict=True
 )
@@ -395,21 +413,20 @@ def test_str_split_regex(regex):
 @pytest.mark.parametrize(
-    "pat, n, expand, error",
+    "pat, n, error",
     [
-        ("", 1, False, ValueError),
-        (re.compile("a"), 1, False, NotImplementedError),
-        (-2.0, 1, False, NotImplementedError),
-        ("a", "a", False, NotImplementedError),
-        ("a", 1, True, NotImplementedError),
+        ("", 1, ValueError),
+        (re.compile("a"), 1, NotImplementedError),
+        (-2.0, 1, NotImplementedError),
+        ("a", "a", NotImplementedError),
     ],
 )
 @sql_count_checker(query_count=0)
-def test_str_split_neg(pat, n, expand, error):
+def test_str_split_neg(pat, n, error):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
     with pytest.raises(error):
-        snow_ser.str.split(pat=pat, n=n, expand=expand, regex=False)
+        snow_ser.str.split(pat=pat, n=n, expand=False, regex=False)
 
 
 @pytest.mark.parametrize("func", ["isdigit", "islower", "isupper", "lower", "upper"])
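
For reference, here is the native pandas behavior that the `expand=True` code path above reproduces: the result is a DataFrame as wide as the largest split count, and rows that produce fewer splits are right-padded with nulls. This snippet is illustrative only and is not part of the patch:

```python
import pandas as pd

ser = pd.Series(["a|b|c", "a|b", None])

# expand=True yields one column per split position; shorter rows are
# padded with nulls, and a null input stays null in every column.
print(ser.str.split("|", expand=True))
#       0     1     2
# 0     a     b     c
# 1     a     b  None
# 2   NaN   NaN   NaN

# A positive n caps the number of splits, so at most n + 1 columns appear.
print(ser.str.split("|", n=1, expand=True))
#       0     1
# 0     a   b|c
# 1     a     b
# 2   NaN   NaN
```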
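The implementation needs two queries (hence `query_count=2` in the new test) because the output width must be known before the columns can be projected: one query finds the maximum split count, and a second builds the result. Below is a minimal pure-Python sketch of that two-pass scheme; `split_expand` is a hypothetical helper standing in for `get_max_splits` plus `output_cols`, and the pandas corner cases for NaN `n` and for `n < -1` with a multi-character pattern are omitted:

```python
from typing import Optional

def split_expand(
    values: list[Optional[str]], pat: str, n: int = -1
) -> list[list[Optional[str]]]:
    # Pass 1: split every value fully to find the widest row, which is
    # what get_max_splits() computes with array_size over a full split.
    full_splits = [v.split(pat) if v is not None else None for v in values]
    max_splits = max((len(s) for s in full_splits if s is not None), default=1)

    # Number of output columns, mirroring final_splits in output_cols().
    final_splits = max_splits if n <= 0 else min(n + 1, max_splits)

    # Pass 2: split with the limit and right-pad short rows with None,
    # mirroring the array_size(col) > i guard around get(col, i).
    rows = []
    for v in values:
        parts = v.split(pat, n) if v is not None else []
        rows.append(
            [parts[i] if i < len(parts) else None for i in range(final_splits)]
        )
    return rows

print(split_expand(["a|b|c", "a|b", None], "|"))
# [['a', 'b', 'c'], ['a', 'b', None], [None, None, None]]
```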