Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

SNOW-1819523: Add support for expand=True in Series.str.split #2832

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
- %X: Locale’s appropriate time representation.
- %%: A literal '%' character.
- Added support for `Series.between`.
- Added support for `expand=True` in `Series.str.split`.

#### Bug Fixes

Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/series_str_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ the method in the left column.
| ``slice_replace`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``split`` | P | ``N`` if `pat` is non-string, `n` is non-numeric, |
| | | `expand` is set, or `regex` is set. |
| | | or `regex` is set. |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``startswith`` | P | ``N`` if the `na` parameter is set to a non-bool |
| | | value. |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
dense_rank,
first_value,
floor,
get,
greatest,
hour,
iff,
Expand Down Expand Up @@ -16787,10 +16788,6 @@ def str_split(
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support non-str 'pat' argument"
)
if expand:
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support 'expand' argument"
)
if regex:
ErrorMessage.not_implemented(
"Snowpark pandas doesn't support 'regex' argument"
Expand Down Expand Up @@ -16881,9 +16878,69 @@ def output_col(
)
return self._replace_non_str(column, new_col)

new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
lambda col_name: output_col(col_name, pat, n)
)
def output_cols(
    column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
) -> list[SnowparkColumn]:
    """Return one output column per split position, padding with NULL.

    Args:
        column: Input string column to split.
        pat: Separator pattern forwarded to ``output_col``.
        n: Maximum number of splits requested by the caller.
        max_n_cols: Largest number of split parts observed in any row.

    Returns:
        A list of columns where element ``i`` holds the ``i``-th split part
        of each row, or NULL when that row produced fewer parts.
    """
    # Build the array-of-splits column once; every output column below
    # indexes into this same array.
    col = output_col(column, pat, n)
    if np.isnan(n):
        # Follow pandas behavior: a NaN `n` expands to a single column.
        final_n_cols = 1
    elif n <= 0:
        # Unlimited splits: emit as many columns as the widest row needs.
        final_n_cols = max_n_cols
    else:
        # `n` splits yield at most n + 1 parts, capped by the widest row.
        final_n_cols = min(n + 1, max_n_cols)

    # For each position, take the i-th array element when the row's split
    # array is long enough; otherwise NULL (pandas pads short rows with NaN).
    return [
        iff(
            array_size(col) > pandas_lit(i),
            get(col, pandas_lit(i)),
            pandas_lit(None),
        )
        for i in range(final_n_cols)
    ]

def max_n_cols() -> int:
    """Compute the maximum number of split parts across all rows.

    Issues an auxiliary query: split every value with no limit
    (``n=-1``, ``expand=False``) into an array column, count the parts
    per row with ARRAY_SIZE, and aggregate the global maximum.
    """
    # Re-use the non-expanding split so every part is kept as a list.
    splits_as_list_frame = self.str_split(
        pat=pat,
        n=-1,
        expand=False,
        regex=regex,
    )._modin_frame

    # Attach a per-row count of the split parts.
    split_counts_frame = splits_as_list_frame.append_column(
        "split_counts",
        array_size(
            col(
                splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
            )
        ),
    )

    # Reduce to a single row holding the maximum count; the count column
    # is the last data column since append_column added it at the end.
    max_count_rows = split_counts_frame.ordered_dataframe.agg(
        max_(
            col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
        ).as_("max_count")
    ).collect()

    return max_count_rows[0][0]

if expand:
cols = output_cols(
col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
pat,
n,
max_n_cols(),
)
new_internal_frame = self._modin_frame.project_columns(
[f"{i}" for i in range(len(cols))],
sfc-gh-helmeleegy marked this conversation as resolved.
Show resolved Hide resolved
cols,
)
else:
new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
lambda col_name: output_col(col_name, pat, n)
)

return SnowflakeQueryCompiler(new_internal_frame)

def str_rsplit(
Expand Down
40 changes: 30 additions & 10 deletions tests/integ/modin/series/test_str_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@

from snowflake.snowpark._internal.utils import TempObjectType
import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result
from tests.integ.modin.utils import (
assert_frame_equal,
assert_series_equal,
eval_snowpark_pandas_result,
)
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker

TEST_DATA = [
Expand Down Expand Up @@ -370,7 +374,7 @@ def test_str_replace_neg(pat, n, repl, error):
@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
@sql_count_checker(query_count=1)
def test_str_split(pat, n):
def test_str_split_expand_false(pat, n):
native_ser = native_pd.Series(TEST_DATA)
snow_ser = pd.Series(native_ser)
eval_snowpark_pandas_result(
Expand All @@ -380,6 +384,23 @@ def test_str_split(pat, n):
)


@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
@sql_count_checker(query_count=2)
def test_str_split_expand_true(pat, n):
    """Verify Series.str.split(expand=True) matches native pandas output."""
    native_ser = native_pd.Series(TEST_DATA)
    snow_ser = pd.Series(native_ser)
    native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None)
    snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None)
    # Currently Snowpark pandas uses an Index object with string values for columns,
    # while native pandas uses a RangeIndex.
    # So we make sure that all corresponding values in the two columns objects are identical
    # (after casting from string to int).
    assert all(snow_df.columns.astype(int).values == native_df.columns.values)
    snow_df.columns = native_df.columns
    assert_frame_equal(snow_df, native_df, check_dtype=False)


@pytest.mark.parametrize("regex", [None, True])
@pytest.mark.xfail(
reason="Snowflake SQL's split function does not support regex", strict=True
Expand All @@ -395,21 +416,20 @@ def test_str_split_regex(regex):


@pytest.mark.parametrize(
    "pat, n, error",
    [
        # Empty pattern is rejected by pandas itself.
        ("", 1, ValueError),
        # Compiled-regex and non-string patterns are unsupported in Snowpark pandas.
        (re.compile("a"), 1, NotImplementedError),
        (-2.0, 1, NotImplementedError),
        # Non-numeric `n` is unsupported.
        ("a", "a", NotImplementedError),
    ],
)
@sql_count_checker(query_count=0)
def test_str_split_neg(pat, n, error):
    """Verify unsupported/invalid split arguments raise the expected errors."""
    native_ser = native_pd.Series(TEST_DATA)
    snow_ser = pd.Series(native_ser)
    with pytest.raises(error):
        snow_ser.str.split(pat=pat, n=n, expand=False, regex=False)


@pytest.mark.parametrize("func", ["isdigit", "islower", "isupper", "lower", "upper"])
Expand Down
Loading