Skip to content

Commit

Permalink
SNOW-1819523: Add support for expand=True in Series.str.split
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-helmeleegy committed Jan 8, 2025
1 parent d92dee9 commit fe00d90
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
- %X: Locale’s appropriate time representation.
- %%: A literal '%' character.
- Added support for `Series.between`.
- Added support for `expand=True` in `Series.str.split`.

#### Bug Fixes

Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/series_str_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ the method in the left column.
| ``slice_replace`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``split``                   | P                               | ``N`` if `pat` is non-string, `n` is non-numeric,  |
|                             |                                 | or `regex` is set.                                 |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``startswith`` | P | ``N`` if the `na` parameter is set to a non-bool |
| | | value. |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
dense_rank,
first_value,
floor,
get,
greatest,
hour,
iff,
Expand Down Expand Up @@ -16787,10 +16788,6 @@ def str_split(
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support non-str 'pat' argument"
)
if expand:
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support 'expand' argument"
)
if regex:
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support 'regex' argument"
Expand Down Expand Up @@ -16881,9 +16878,69 @@ def output_col(
)
return self._replace_non_str(column, new_col)

new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
lambda col_name: output_col(col_name, pat, n)
)
def output_cols(
    column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
) -> list[SnowparkColumn]:
    """
    Build one Snowpark column expression per expanded split position.

    Splits `column` via `output_col` (which yields an array of parts) and
    returns a list of expressions that extract element 0, 1, ... from that
    array, padding rows with fewer parts using NULL.

    Args:
        column: the string column to split.
        pat: split pattern forwarded to `output_col`.
        n: pandas-style split limit (NaN, <= 0, or a positive count).
        max_n_cols: the largest number of parts observed in any row.

    Returns:
        A list of column expressions, one per output column.
    """
    split_col = output_col(column, pat, n)

    # Decide how many output columns to emit, mirroring pandas semantics:
    # NaN -> a single column; n <= 0 -> unlimited, so use the observed
    # maximum; otherwise at most n splits, i.e. n + 1 pieces, capped by
    # the maximum actually present in the data.
    if np.isnan(n):
        n_cols = 1
    elif n <= 0:
        n_cols = max_n_cols
    else:
        n_cols = min(n + 1, max_n_cols)

    result = []
    for pos in range(n_cols):
        # Guard the array access: rows whose split produced fewer than
        # pos + 1 parts get NULL rather than an out-of-range GET.
        result.append(
            iff(
                array_size(split_col) > pandas_lit(pos),
                get(split_col, pandas_lit(pos)),
                pandas_lit(None),
            )
        )
    return result

def max_n_cols() -> int:
    """
    Return the largest number of split parts produced by any row.

    Used to size the expanded output when ``expand=True``. Note this
    issues an eager Snowflake query via ``collect()``.

    Closure variables (from the enclosing split implementation —
    presumably ``str_split``; confirm against the surrounding method):
    ``self``, ``pat`` and ``regex``.
    """
    # Split every row with no limit (n=-1, expand=False) so each row
    # yields its full list of parts as a single array column.
    splits_as_list_frame = self.str_split(
        pat=pat,
        n=-1,
        expand=False,
        regex=regex,
    )._modin_frame

    # Append a column holding each row's part count (array length).
    split_counts_frame = splits_as_list_frame.append_column(
        "split_counts",
        array_size(
            col(
                splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
            )
        ),
    )

    # Aggregate down to one row: the maximum part count over all rows.
    # The appended count column is the last data column ([-1]).
    max_count_rows = split_counts_frame.ordered_dataframe.agg(
        max_(
            col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
        ).as_("max_count")
    ).collect()

    # Single aggregate row, single column -> the scalar maximum.
    return max_count_rows[0][0]

if expand:
cols = output_cols(
col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
pat,
n,
max_n_cols(),
)
new_internal_frame = self._modin_frame.project_columns(
[f"{i}" for i in range(len(cols))],
cols,
)
else:
new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
lambda col_name: output_col(col_name, pat, n)
)

return SnowflakeQueryCompiler(new_internal_frame)

def str_rsplit(
Expand Down
40 changes: 30 additions & 10 deletions tests/integ/modin/series/test_str_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@

from snowflake.snowpark._internal.utils import TempObjectType
import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
from tests.integ.modin.utils import (
assert_frame_equal,
assert_series_equal,
eval_snowpark_pandas_result,
)
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker

TEST_DATA = [
Expand Down Expand Up @@ -370,7 +374,7 @@ def test_str_replace_neg(pat, n, repl, error):
@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
@sql_count_checker(query_count=1)
def test_str_split_expand_false(pat, n):
native_ser = native_pd.Series(TEST_DATA)
snow_ser = pd.Series(native_ser)
eval_snowpark_pandas_result(
Expand All @@ -380,6 +384,23 @@ def test_str_split(pat, n):
)


@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
@sql_count_checker(query_count=2)
def test_str_split_expand_true(pat, n):
    """Series.str.split(expand=True) should match native pandas output."""
    expected_ser = native_pd.Series(TEST_DATA)
    result_ser = pd.Series(expected_ser)
    expected_df = expected_ser.str.split(pat=pat, n=n, expand=True, regex=None)
    result_df = result_ser.str.split(pat=pat, n=n, expand=True, regex=None)
    # Snowpark pandas currently labels the expanded columns with an Index of
    # string values, whereas native pandas produces a RangeIndex. Compare the
    # labels value-by-value after casting the strings to int, then align the
    # labels so the frame comparison below can proceed.
    assert all(result_df.columns.astype(int).values == expected_df.columns.values)
    result_df.columns = expected_df.columns
    assert_frame_equal(result_df, expected_df, check_dtype=False)


@pytest.mark.parametrize("regex", [None, True])
@pytest.mark.xfail(
reason="Snowflake SQL's split function does not support regex", strict=True
Expand All @@ -395,21 +416,20 @@ def test_str_split_regex(regex):


@pytest.mark.parametrize(
    "pat, n, error",
    [
        # Empty pattern is rejected by pandas semantics.
        ("", 1, ValueError),
        # Compiled-regex and non-string patterns are not supported yet.
        (re.compile("a"), 1, NotImplementedError),
        (-2.0, 1, NotImplementedError),
        # Non-numeric n is not supported yet.
        ("a", "a", NotImplementedError),
    ],
)
@sql_count_checker(query_count=0)
def test_str_split_neg(pat, n, error):
    """Invalid `pat`/`n` arguments raise eagerly, without issuing any query."""
    native_ser = native_pd.Series(TEST_DATA)
    snow_ser = pd.Series(native_ser)
    with pytest.raises(error):
        snow_ser.str.split(pat=pat, n=n, expand=False, regex=False)


@pytest.mark.parametrize("func", ["isdigit", "islower", "isupper", "lower", "upper"])
Expand Down

0 comments on commit fe00d90

Please sign in to comment.