diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index 45384ace047..aef89e7f20e 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -16835,6 +16835,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which seems to leave the input column as is
+                # whenever the above condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
@@ -16879,17 +16885,32 @@ def output_col(
             return self._replace_non_str(column, new_col)
 
         def output_cols(
-            column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
         ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+            Args:
+                column: input column
+                pat: string to split on
+                n: limit on the number of output splits
+                max_splits: maximum number of achievable splits across all values in the input column
+            """
             col = output_col(column, pat, n)
-            final_n_cols = 0
+            final_splits = 0
+
             if np.isnan(n):
                 # Follow pandas behavior
-                final_n_cols = 1
+                final_splits = 1
             elif n <= 0:
-                final_n_cols = max_n_cols
+                final_splits = max_splits
             else:
-                final_n_cols = min(n + 1, max_n_cols)
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which seems to leave the input column as is
+                # whenever the above condition is satisfied.
+                final_splits = 1
 
             return [
                 iff(
@@ -16897,10 +16918,14 @@ def output_cols(
                     get(col, pandas_lit(i)),
                     pandas_lit(None),
                 )
-                for i in range(final_n_cols)
+                for i in range(final_splits)
             ]
 
-        def max_n_cols() -> int:
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
             splits_as_list_frame = self.str_split(
                 pat=pat,
                 n=-1,
@@ -16930,10 +16955,10 @@ def max_n_cols() -> int:
                 col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
                 pat,
                 n,
-                max_n_cols(),
+                get_max_splits(),
             )
             new_internal_frame = self._modin_frame.project_columns(
-                [f"{i}" for i in range(len(cols))],
+                list(range(len(cols))),
                 cols,
             )
         else:
diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py
index 572c10c188e..f0ebe7d23ec 100644
--- a/tests/integ/modin/series/test_str_accessor.py
+++ b/tests/integ/modin/series/test_str_accessor.py
@@ -13,7 +13,6 @@
 from snowflake.snowpark._internal.utils import TempObjectType
 import snowflake.snowpark.modin.plugin  # noqa: F401
 from tests.integ.modin.utils import (
-    assert_frame_equal,
     assert_series_equal,
     eval_snowpark_pandas_result,
 )
@@ -371,7 +370,7 @@ def test_str_replace_neg(pat, n, repl, error):
         snow_ser.str.replace(pat=pat, repl=repl, n=n)
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "non_occurrence_pat", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=1)
 def test_str_split_expand_false(pat, n):
@@ -384,21 +383,17 @@ def test_str_split_expand_false(pat, n):
     )
 
 
-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "no_occurrence_pat", "|", "%"])
 @pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
 @sql_count_checker(query_count=2)
 def test_str_split_expand_true(pat, n):
     native_ser = native_pd.Series(TEST_DATA)
     snow_ser = pd.Series(native_ser)
-    native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-    snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-    # Currently Snowpark pandas uses an Index object with string values for columns,
-    # while native pandas uses a RangeIndex.
-    # So we make sure that all corresponding values in the two columns objects are identical
-    # (after casting from string to int).
-    assert all(snow_df.columns.astype(int).values == native_df.columns.values)
-    snow_df.columns = native_df.columns
-    assert_frame_equal(snow_df, native_df, check_dtype=False)
+    eval_snowpark_pandas_result(
+        snow_ser,
+        native_ser,
+        lambda ser: ser.str.split(pat=pat, n=n, expand=True, regex=None),
+    )
 
 
 @pytest.mark.parametrize("regex", [None, True])