From fe00d9055eb29b306ec0108666174d58c2c549a2 Mon Sep 17 00:00:00 2001 From: Hazem Elmeleegy Date: Tue, 7 Jan 2025 17:18:36 -0800 Subject: [PATCH] SNOW-1819523: Add support for expand=True in Series.str.split --- CHANGELOG.md | 1 + .../modin/supported/series_str_supported.rst | 2 +- .../compiler/snowflake_query_compiler.py | 71 +++++++++++++++++-- tests/integ/modin/series/test_str_accessor.py | 40 ++++++++--- 4 files changed, 96 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c3511e3f9c..2a55b5cb5ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,7 @@ - %X: Locale’s appropriate time representation. - %%: A literal '%' character. - Added support for `Series.between`. +- Added support for `expand=True` in `Series.str.split`. #### Bug Fixes diff --git a/docs/source/modin/supported/series_str_supported.rst b/docs/source/modin/supported/series_str_supported.rst index 702c699f153..7c96a0ac362 100644 --- a/docs/source/modin/supported/series_str_supported.rst +++ b/docs/source/modin/supported/series_str_supported.rst @@ -119,7 +119,7 @@ the method in the left column. | ``slice_replace`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``split`` | P | ``N`` if `pat` is non-string, `n` is non-numeric, | -| | | `expand` is set, or `regex` is set. | +| | | or `regex` is set. | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``startswith`` | P | ``N`` if the `na` parameter is set to a non-bool | | | | value. | diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index e483d6aa6f8..45384ace047 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -105,6 +105,7 @@ dense_rank, first_value, floor, + get, greatest, hour, iff, @@ -16787,10 +16788,6 @@ def str_split( ErrorMessage.not_implemented( "Snowpark pandas doesn't support non-str 'pat' argument" ) - if expand: - ErrorMessage.not_implemented( - "Snowpark pandas doesn't support 'expand' argument" - ) if regex: ErrorMessage.not_implemented( "Snowpark pandas doesn't support 'regex' argument" @@ -16881,9 +16878,69 @@ def output_col( ) return self._replace_non_str(column, new_col) - new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns( - lambda col_name: output_col(col_name, pat, n) - ) + def output_cols( + column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int + ) -> list[SnowparkColumn]: + col = output_col(column, pat, n) + final_n_cols = 0 + if np.isnan(n): + # Follow pandas behavior + final_n_cols = 1 + elif n <= 0: + final_n_cols = max_n_cols + else: + final_n_cols = min(n + 1, max_n_cols) + + return [ + iff( + array_size(col) > pandas_lit(i), + get(col, pandas_lit(i)), + pandas_lit(None), + ) + for i in range(final_n_cols) + ] + + def max_n_cols() -> int: + splits_as_list_frame = self.str_split( + pat=pat, + n=-1, + expand=False, + regex=regex, + )._modin_frame + + split_counts_frame = splits_as_list_frame.append_column( + "split_counts", + array_size( + col( + splits_as_list_frame.data_column_snowflake_quoted_identifiers[0] + ) + ), + ) + + max_count_rows = split_counts_frame.ordered_dataframe.agg( + max_( + col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1]) + ).as_("max_count") + ).collect() + + return max_count_rows[0][0] + + if expand: + cols = output_cols( + col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]), + pat, + n, + max_n_cols(), + ) + new_internal_frame = self._modin_frame.project_columns( + [f"{i}" for i in range(len(cols))], + cols, + ) + else: + new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns( + lambda col_name: output_col(col_name, pat, n) + ) + return SnowflakeQueryCompiler(new_internal_frame) def str_rsplit( diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py index 2dee5ae3232..572c10c188e 100644 --- a/tests/integ/modin/series/test_str_accessor.py +++ b/tests/integ/modin/series/test_str_accessor.py @@ -12,7 +12,11 @@ from snowflake.snowpark._internal.utils import TempObjectType import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result +from tests.integ.modin.utils import ( + assert_frame_equal, + assert_series_equal, + eval_snowpark_pandas_result, +) from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker TEST_DATA = [ @@ -370,7 +374,7 @@ def test_str_replace_neg(pat, n, repl, error): @pytest.mark.parametrize("pat", [None, "a", "|", "%"]) @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2]) @sql_count_checker(query_count=1) -def test_str_split(pat, n): +def test_str_split_expand_false(pat, n): native_ser = native_pd.Series(TEST_DATA) snow_ser = pd.Series(native_ser) eval_snowpark_pandas_result( @@ -380,6 +384,23 @@ def test_str_split(pat, n): ) +@pytest.mark.parametrize("pat", [None, "a", "|", "%"]) +@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2]) +@sql_count_checker(query_count=2) +def test_str_split_expand_true(pat, n): + native_ser = native_pd.Series(TEST_DATA) + snow_ser = pd.Series(native_ser) + native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None) + snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None) + # Currently Snowpark pandas uses an Index object with string values for columns, + # while native pandas uses a RangeIndex. + # So we make sure that all corresponding values in the two columns objects are identical + # (after casting from string to int). + assert all(snow_df.columns.astype(int).values == native_df.columns.values) + snow_df.columns = native_df.columns + assert_frame_equal(snow_df, native_df, check_dtype=False) + + @pytest.mark.parametrize("regex", [None, True]) @pytest.mark.xfail( reason="Snowflake SQL's split function does not support regex", strict=True @@ -395,21 +416,20 @@ def test_str_split_regex(regex): @pytest.mark.parametrize( - "pat, n, expand, error", + "pat, n, error", [ - ("", 1, False, ValueError), - (re.compile("a"), 1, False, NotImplementedError), - (-2.0, 1, False, NotImplementedError), - ("a", "a", False, NotImplementedError), - ("a", 1, True, NotImplementedError), + ("", 1, ValueError), + (re.compile("a"), 1, NotImplementedError), + (-2.0, 1, NotImplementedError), + ("a", "a", NotImplementedError), ], ) @sql_count_checker(query_count=0) -def test_str_split_neg(pat, n, expand, error): +def test_str_split_neg(pat, n, error): native_ser = native_pd.Series(TEST_DATA) snow_ser = pd.Series(native_ser) with pytest.raises(error): - snow_ser.str.split(pat=pat, n=n, expand=expand, regex=False) + snow_ser.str.split(pat=pat, n=n, expand=False, regex=False) @pytest.mark.parametrize("func", ["isdigit", "islower", "isupper", "lower", "upper"])