From fe00d9055eb29b306ec0108666174d58c2c549a2 Mon Sep 17 00:00:00 2001
From: Hazem Elmeleegy
Date: Tue, 7 Jan 2025 17:18:36 -0800
Subject: [PATCH 1/3] SNOW-1819523: Add support for expand=True in Series.str.split

---
 CHANGELOG.md                                  |  1 +
 .../modin/supported/series_str_supported.rst  |  2 +-
 .../compiler/snowflake_query_compiler.py      | 71 +++++++++++++++++--
 tests/integ/modin/series/test_str_accessor.py | 40 ++++++++---
 4 files changed, 96 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c3511e3f9c..2a55b5cb5ce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,7 @@
 - %X: Locale’s appropriate time representation.
 - %%: A literal '%' character.
 - Added support for `Series.between`.
+- Added support for `expand=True` in `Series.str.split`.
 
 #### Bug Fixes
 
diff --git a/docs/source/modin/supported/series_str_supported.rst b/docs/source/modin/supported/series_str_supported.rst
index 702c699f153..7c96a0ac362 100644
--- a/docs/source/modin/supported/series_str_supported.rst
+++ b/docs/source/modin/supported/series_str_supported.rst
@@ -119,7 +119,7 @@ the method in the left column.
 | ``slice_replace``           | N                               |                                                    |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``split``                   | P                               | ``N`` if `pat` is non-string, `n` is non-numeric,  |
-|                             |                                 | `expand` is set, or `regex` is set.                |
+|                             |                                 | or `regex` is set.                                 |
 +-----------------------------+---------------------------------+----------------------------------------------------+
 | ``startswith``              | P                               | ``N`` if the `na` parameter is set to a non-bool   |
 |                             |                                 | value.                                             |
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index e483d6aa6f8..45384ace047 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -105,6 +105,7 @@
     dense_rank,
     first_value,
     floor,
+    get,
     greatest,
     hour,
     iff,
@@ -16787,10 +16788,6 @@ def str_split(
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support non-str 'pat' argument"
             )
-        if expand:
-            ErrorMessage.not_implemented(
-                "Snowpark pandas doesn't support 'expand' argument"
-            )
         if regex:
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support 'regex' argument"
@@ -16881,9 +16878,69 @@ def output_col(
             )
             return self._replace_non_str(column, new_col)
 
-        new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
-            lambda col_name: output_col(col_name, pat, n)
-        )
+        def output_cols(
+            column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
+        ) -> list[SnowparkColumn]:
+            col = output_col(column, pat, n)
+            final_n_cols = 0
+            if np.isnan(n):
+                # Follow pandas behavior
+                final_n_cols = 1
+            elif n <= 0:
+                final_n_cols = max_n_cols
+            else:
+                final_n_cols = min(n + 1, max_n_cols)
+
+            return [
+                iff(
+                    array_size(col) > pandas_lit(i),
+                    get(col, pandas_lit(i)),
+                    pandas_lit(None),
+                )
+                for i in range(final_n_cols)
+            ]
+
+        def max_n_cols() -> int:
+            splits_as_list_frame = self.str_split(
+                pat=pat,
+                n=-1,
+                expand=False,
+                regex=regex,
+            )._modin_frame
+
+            split_counts_frame = splits_as_list_frame.append_column(
+                "split_counts",
+                array_size(
+                    col(
+                        splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
+                    )
+                ),
+            )
+
+            max_count_rows = split_counts_frame.ordered_dataframe.agg(
+                max_(
+                    col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
+                ).as_("max_count")
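+                # The agg(max_(...)) call builds a one-row result; the collect()
+                # below executes it eagerly. This is the extra query that
+                # expand=True costs (query_count=2 in the tests vs. 1 otherwise).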
+            ).collect()
+
+            return max_count_rows[0][0]
+
+        if expand:
+            cols = output_cols(
+                col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
+                pat,
+                n,
+                max_n_cols(),
+            )
+            new_internal_frame = self._modin_frame.project_columns(
+                [f"{i}" for i in range(len(cols))],
+                cols,
+            )
+        else:
+            new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
+                lambda col_name: output_col(col_name, pat, n)
+            )
+
         return SnowflakeQueryCompiler(new_internal_frame)
 
     def str_rsplit(
diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py
index 2dee5ae3232..572c10c188e 100644
--- a/tests/integ/modin/series/test_str_accessor.py
+++ b/tests/integ/modin/series/test_str_accessor.py
@@ -12,7 +12,11 @@
 from snowflake.snowpark._internal.utils import TempObjectType
 import snowflake.snowpark.modin.plugin  # noqa: F401
-from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
+from tests.integ.modin.utils import (
+    assert_frame_equal,
+    assert_series_equal,
+    eval_snowpark_pandas_result,
+)
 from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker
 
 TEST_DATA = [
@@ -370,7 +374,7 @@ def test_str_replace_neg(pat, n, repl, error):
 @pytest.mark.parametrize("pat", [None, "a", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=1)
-def test_str_split(pat, n):
+def test_str_split_expand_false(pat, n):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
     eval_snowpark_pandas_result(
@@ -380,6 +384,23 @@ def test_str_split(pat, n):
     )
 
 
+@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
+@sql_count_checker(query_count=2)
+def test_str_split_expand_true(pat, n):
+    native_ser = native_pd.Series(TEST_DATA)
+    snow_ser = pd.Series(native_ser)
+    native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None)
+    snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None)
+    # Currently Snowpark pandas uses an Index object with string values for columns,
+    # while native pandas uses a RangeIndex.
+    # So we make sure that all corresponding values in the two columns objects are identical
+    # (after casting from string to int).
+    assert all(snow_df.columns.astype(int).values == native_df.columns.values)
+    snow_df.columns = native_df.columns
+    assert_frame_equal(snow_df, native_df, check_dtype=False)
+
+
 @pytest.mark.parametrize("regex", [None, True])
 @pytest.mark.xfail(
     reason="Snowflake SQL's split function does not support regex", strict=True
@@ -395,21 +416,20 @@ def test_str_split_regex(regex):
 
 
 @pytest.mark.parametrize(
-    "pat, n, expand, error",
+    "pat, n, error",
     [
-        ("", 1, False, ValueError),
-        (re.compile("a"), 1, False, NotImplementedError),
-        (-2.0, 1, False, NotImplementedError),
-        ("a", "a", False, NotImplementedError),
-        ("a", 1, True, NotImplementedError),
+        ("", 1, ValueError),
+        (re.compile("a"), 1, NotImplementedError),
+        (-2.0, 1, NotImplementedError),
+        ("a", "a", NotImplementedError),
     ],
 )
 @sql_count_checker(query_count=0)
-def test_str_split_neg(pat, n, expand, error):
+def test_str_split_neg(pat, n, error):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
     with pytest.raises(error):
-        snow_ser.str.split(pat=pat, n=n, expand=expand, regex=False)
+        snow_ser.str.split(pat=pat, n=n, expand=False, regex=False)
 
 
 @pytest.mark.parametrize("func", ["isdigit", "islower", "isupper", "lower", "upper"])

From 2b668f208428989c3b02a0e1a9d4a46cc7375f76 Mon Sep 17 00:00:00 2001
From: Hazem Elmeleegy
Date: Thu, 9 Jan 2025 02:00:23 -0800
Subject: [PATCH 2/3] address comments

---
 .../compiler/snowflake_query_compiler.py      | 43 +++++++++++++++----
 tests/integ/modin/series/test_str_accessor.py | 19 +++-----
 2 files changed, 41 insertions(+), 21 deletions(-)

diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index 45384ace047..aef89e7f20e 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -16835,6 +16835,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which seems to leave the input column as is
+                # whenever the above condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
@@ -16879,17 +16885,32 @@ def output_col(
             return self._replace_non_str(column, new_col)
 
         def output_cols(
-            column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
         ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+            Args:
+            column: input column
+            pat: string to split on
+            n: limit on the number of output splits
+            max_splits: maximum number of achievable splits across all values in the input column
+            """
             col = output_col(column, pat, n)
-            final_n_cols = 0
+            final_splits = 0
+
             if np.isnan(n):
                 # Follow pandas behavior
-                final_n_cols = 1
+                final_splits = 1
             elif n <= 0:
-                final_n_cols = max_n_cols
+                final_splits = max_splits
             else:
-                final_n_cols = min(n + 1, max_n_cols)
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which seems to leave the input column as is
+                # whenever the above condition is satisfied.
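+                # Since the value is kept whole (a single-element array), it can
+                # only ever expand into one output column.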
+                final_splits = 1
 
             return [
                 iff(
@@ -16897,10 +16918,14 @@ def output_cols(
                     get(col, pandas_lit(i)),
                     pandas_lit(None),
                 )
-                for i in range(final_n_cols)
+                for i in range(final_splits)
             ]
 
-        def max_n_cols() -> int:
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
             splits_as_list_frame = self.str_split(
                 pat=pat,
                 n=-1,
@@ -16930,10 +16955,10 @@ def max_n_cols() -> int:
                 col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
                 pat,
                 n,
-                max_n_cols(),
+                get_max_splits(),
             )
             new_internal_frame = self._modin_frame.project_columns(
-                [f"{i}" for i in range(len(cols))],
+                list(range(len(cols))),
                 cols,
             )
         else:
diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py
index 572c10c188e..f0ebe7d23ec 100644
--- a/tests/integ/modin/series/test_str_accessor.py
+++ b/tests/integ/modin/series/test_str_accessor.py
@@ -13,7 +13,6 @@
 from snowflake.snowpark._internal.utils import TempObjectType
 import snowflake.snowpark.modin.plugin  # noqa: F401
 from tests.integ.modin.utils import (
-    assert_frame_equal,
     assert_series_equal,
     eval_snowpark_pandas_result,
 )
@@ -371,7 +370,7 @@ def test_str_replace_neg(pat, n, repl, error):
         snow_ser.str.replace(pat=pat, repl=repl, n=n)
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "non_occurrence_pat", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=1)
 def test_str_split_expand_false(pat, n):
@@ -384,21 +383,17 @@ def test_str_split_expand_false(pat, n):
     )
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "no_occurrence_pat", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=2)
 def test_str_split_expand_true(pat, n):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
-    native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-    snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-    # Currently Snowpark pandas uses an Index object with string values for columns,
-    # while native pandas uses a RangeIndex.
-    # So we make sure that all corresponding values in the two columns objects are identical
-    # (after casting from string to int).
-    assert all(snow_df.columns.astype(int).values == native_df.columns.values)
-    snow_df.columns = native_df.columns
-    assert_frame_equal(snow_df, native_df, check_dtype=False)
+    eval_snowpark_pandas_result(
+        snow_ser,
+        native_ser,
+        lambda ser: ser.str.split(pat=pat, n=n, expand=True, regex=None),
+    )
 
 
 @pytest.mark.parametrize("regex", [None, True])
@@ -395,21 +416,20 @@ def test_str_split_regex(regex):

From 38c3276578ad6a6c7366d7a0d50cc3761d1fe113 Mon Sep 17 00:00:00 2001
From: Hazem Elmeleegy
Date: Thu, 9 Jan 2025 10:51:39 -0800
Subject: [PATCH 3/3] address comments

---
 .../plugin/compiler/snowflake_query_compiler.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index 0dcea242b76..6861a432750 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -16895,10 +16895,15 @@ def output_cols(
             Returns the list of columns that the input column will be split into.
             This is only used when expand=True.
             Args:
-            column: input column
-            pat: string to split on
-            n: limit on the number of output splits
-            max_splits: maximum number of achievable splits across all values in the input column
+            column : SnowparkColumn
+                Input column
+            pat : str
+                String to split on
+            n : int
+                Limit on the number of output splits
+            max_splits : int
+                Maximum number of achievable splits across all values in the input column.
+                This is needed so that rows with fewer splits than desired can be padded with nulls.
             """
             col = output_col(column, pat, n)
             final_splits = 0
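
Note on semantics (illustration only, not part of the patches): the target behavior is native pandas' Series.str.split with expand=True. Below is a minimal sketch, assuming an ordinary recent pandas install and made-up sample values. The NaN padding shown for short rows is what the iff(array_size(col) > i, get(col, i), NULL) projection above reproduces in SQL, and the column count is the maximum split count across rows (capped at n + 1 when n > 0), which is why get_max_splits() issues its own aggregation query before the projection.

    import pandas as pd

    s = pd.Series(["a|b|c", "a|b", None])

    # expand=False (previously supported): one column holding lists;
    # missing inputs pass through unchanged.
    # 0    [a, b, c]
    # 1       [a, b]
    # 2         None
    print(s.str.split("|", expand=False))

    # expand=True (added by this series): one output column per split piece,
    # with shorter rows padded by NaN.
    #      0    1    2
    # 0    a    b    c
    # 1    a    b  NaN
    # 2  NaN  NaN  NaN
    print(s.str.split("|", expand=True))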