Commit

address comments
sfc-gh-helmeleegy committed Jan 9, 2025
1 parent fe00d90 commit 2b668f2
Showing 2 changed files with 41 additions and 21 deletions.
@@ -16835,6 +16835,12 @@ def output_col(
            if np.isnan(n):
                # Follow pandas behavior
                return pandas_lit(np.nan)
+           elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+               # Follow pandas behavior, which seems to leave the input column as is
+               # whenever the above condition is satisfied.
+               new_col = iff(
+                   column.is_null(), pandas_lit(None), array_construct(column)
+               )
            elif n <= 0:
                # If all possible splits are requested, we just use SQL's split function.
                new_col = builtin("split")(new_col, pandas_lit(new_pat))
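
For context on the pandas behavior this new branch mirrors: with the default regex=None, pandas treats a pattern longer than one character as a regular expression, and re.split performs no splits at all when maxsplit is negative, while str.split treats any negative maxsplit as "split everything". A minimal sketch with illustrative values, assuming a negative n is passed through to re.split as maxsplit:

    import re
    import pandas as native_pd

    # re.split performs no splits when maxsplit is negative ...
    re.split("ab", "xabyabz", maxsplit=-2)  # ['xabyabz']
    # ... whereas str.split (used for single-character patterns) splits fully.
    "xabyabz".split("ab", -2)  # ['x', 'y', 'z']

    # Hence, for n < -1 and a multi-character pattern, each non-null value
    # seems to come back unsplit, wrapped in a one-element list:
    native_pd.Series(["xabyabz", None]).str.split("ab", n=-2)
    # 0    [xabyabz]
    # 1         None
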
@@ -16879,28 +16885,47 @@ def output_col(
            return self._replace_non_str(column, new_col)

        def output_cols(
-           column: SnowparkColumn, pat: Optional[str], n: int, max_n_cols: int
+           column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
        ) -> list[SnowparkColumn]:
+           """
+           Returns the list of columns that the input column will be split into.
+           This is only used when expand=True.
+           Args:
+               column: input column
+               pat: string to split on
+               n: limit on the number of output splits
+               max_splits: maximum number of achievable splits across all values in the input column
+           """
            col = output_col(column, pat, n)
-           final_n_cols = 0
+           final_splits = 0

            if np.isnan(n):
                # Follow pandas behavior
-               final_n_cols = 1
+               final_splits = 1
            elif n <= 0:
-               final_n_cols = max_n_cols
+               final_splits = max_splits
            else:
-               final_n_cols = min(n + 1, max_n_cols)
+               final_splits = min(n + 1, max_splits)

+           if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+               # Follow pandas behavior, which seems to leave the input column as is
+               # whenever the above condition is satisfied.
+               final_splits = 1
+
            return [
                iff(
                    array_size(col) > pandas_lit(i),
                    get(col, pandas_lit(i)),
                    pandas_lit(None),
                )
-               for i in range(final_n_cols)
+               for i in range(final_splits)
            ]

-       def max_n_cols() -> int:
+       def get_max_splits() -> int:
+           """
+           Returns the maximum number of splits achievable
+           across all values stored in the input column.
+           """
            splits_as_list_frame = self.str_split(
                pat=pat,
                n=-1,
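
The list comprehension above mirrors how pandas pads ragged split results when expand=True: output column i holds the i-th piece only when the row's split array has more than i elements, and a missing value otherwise, which is exactly what the array_size guard expresses. An illustrative native-pandas run:

    import pandas as native_pd

    native_pd.Series(["a-b-c", "a-b"]).str.split("-", expand=True)
    #    0  1    2
    # 0  a  b    c
    # 1  a  b  NaN
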
@@ -16930,10 +16955,10 @@ def max_n_cols() -> int:
                col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
                pat,
                n,
-               max_n_cols(),
+               get_max_splits(),
            )
            new_internal_frame = self._modin_frame.project_columns(
-               [f"{i}" for i in range(len(cols))],
+               list(range(len(cols))),
                cols,
            )
        else:
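The body of get_max_splits is mostly elided from this diff; per its docstring, it splits every value with no limit (n=-1) and takes the largest number of resulting pieces. A hedged standalone sketch of that idea in Snowpark Python, where the session, table, column, and function names are hypothetical:

    from snowflake.snowpark import Session
    from snowflake.snowpark.functions import array_size, builtin, col, lit
    from snowflake.snowpark.functions import max as max_

    def max_achievable_splits(session: Session, table: str, column: str, pat: str) -> int:
        # Split each value on pat with no limit, measure each resulting array,
        # and aggregate to the maximum over all rows (a single query).
        df = session.table(table)
        return df.select(
            max_(array_size(builtin("split")(col(column), lit(pat))))
        ).collect()[0][0]
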
19 changes: 7 additions & 12 deletions tests/integ/modin/series/test_str_accessor.py
@@ -13,7 +13,6 @@
from snowflake.snowpark._internal.utils import TempObjectType
import snowflake.snowpark.modin.plugin  # noqa: F401
from tests.integ.modin.utils import (
-   assert_frame_equal,
    assert_series_equal,
    eval_snowpark_pandas_result,
)
@@ -371,7 +370,7 @@ def test_str_replace_neg(pat, n, repl, error):
        snow_ser.str.replace(pat=pat, repl=repl, n=n)


-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "non_occurrence_pat", "|", "%"])
@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
@sql_count_checker(query_count=1)
def test_str_split_expand_false(pat, n):
@@ -384,21 +383,17 @@ def test_str_split_expand_false(pat, n):
    )


-@pytest.mark.parametrize("pat", [None, "a", "|", "%"])
+@pytest.mark.parametrize("pat", [None, "a", "ab", "no_occurrence_pat", "|", "%"])
@pytest.mark.parametrize("n", [None, np.nan, 3, 2, 1, 0, -1, -2])
@sql_count_checker(query_count=2)
def test_str_split_expand_true(pat, n):
    native_ser = native_pd.Series(TEST_DATA)
    snow_ser = pd.Series(native_ser)
-   native_df = native_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-   snow_df = snow_ser.str.split(pat=pat, n=n, expand=True, regex=None)
-   # Currently Snowpark pandas uses an Index object with string values for columns,
-   # while native pandas uses a RangeIndex.
-   # So we make sure that all corresponding values in the two columns objects are identical
-   # (after casting from string to int).
-   assert all(snow_df.columns.astype(int).values == native_df.columns.values)
-   snow_df.columns = native_df.columns
-   assert_frame_equal(snow_df, native_df, check_dtype=False)
+   eval_snowpark_pandas_result(
+       snow_ser,
+       native_ser,
+       lambda ser: ser.str.split(pat=pat, n=n, expand=True, regex=None),
+   )


@pytest.mark.parametrize("regex", [None, True])
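The hand-rolled column reconciliation deleted above became unnecessary once the query compiler started labeling the expanded columns with list(range(len(cols))): integer labels line up with the RangeIndex that native pandas produces under expand=True, so eval_snowpark_pandas_result can compare the two frames directly. For reference:

    import pandas as native_pd

    native_pd.Series(["a-b"]).str.split("-", expand=True).columns
    # RangeIndex(start=0, stop=2, step=1)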
