diff --git a/.github/workflows/daily_precommit.yml b/.github/workflows/daily_precommit.yml index f1277ff7148..d0cfbc2db59 100644 --- a/.github/workflows/daily_precommit.yml +++ b/.github/workflows/daily_precommit.yml @@ -546,75 +546,6 @@ jobs: .tox/.coverage .tox/coverage.xml - test-snowpark-disable-multithreading-mode: - name: Test Snowpark Multithreading Disabled py-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} - needs: build - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest-64-cores] - python-version: ["3.9"] - cloud-provider: [aws] - steps: - - name: Checkout Code - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - name: Decrypt parameters.py - shell: bash - run: .github/scripts/decrypt_parameters.sh - env: - PARAMETER_PASSWORD: ${{ secrets.PARAMETER_PASSWORD }} - CLOUD_PROVIDER: ${{ matrix.cloud-provider }} - - name: Install protoc - shell: bash - run: .github/scripts/install_protoc.sh - - name: Download wheel(s) - uses: actions/download-artifact@v4 - with: - name: wheel - path: dist - - name: Show wheels downloaded - run: ls -lh dist - shell: bash - - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel - - name: Install tox - run: python -m pip install tox - - name: Run tests (excluding doctests) - run: python -m tox -e "py${PYTHON_VERSION/\./}-notmultithreaded-ci" - env: - PYTHON_VERSION: ${{ matrix.python-version }} - cloud_provider: ${{ matrix.cloud-provider }} - PYTEST_ADDOPTS: --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 - shell: bash - - name: Run local tests - run: python -m tox -e "py${PYTHON_VERSION/\./}-localnotmultithreaded-ci" - env: - PYTHON_VERSION: ${{ matrix.python-version }} - cloud_provider: ${{ matrix.cloud-provider }} - PYTEST_ADDOPTS: --color=yes --tb=short - TOX_PARALLEL_NO_SPINNER: 1 - shell: bash - - name: Combine coverages - run: python -m tox -e coverage --skip-missing-interpreters false - shell: bash - env: - SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1 - - uses: actions/upload-artifact@v4 - with: - include-hidden-files: true - name: coverage_${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}-snowpark-multithreading - path: | - .tox/.coverage - .tox/coverage.xml - combine-coverage: if: ${{ success() || failure() }} name: Combine coverage diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ebf01e80d1..63314dc52d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,11 +19,13 @@ #### Improvements - Updated README.md to include instructions on how to verify package signatures using `cosign`. +- Added an option `keep_column_order` for keeping original column order in `DataFrame.with_column` and `DataFrame.with_columns`. #### Bug Fixes - Fixed a bug in local testing mode that caused a column to contain None when it should contain 0 -- Fixed a bug in StructField.from_json that prevented TimestampTypes with tzinfo from being parsed correctly. +- Fixed a bug in `StructField.from_json` that prevented TimestampTypes with tzinfo from being parsed correctly. +- Fixed a bug in function `date_format` that caused an error when the input column was date type or timestamp type. ### Snowpark pandas API Updates @@ -49,6 +51,7 @@ - %X: Locale’s appropriate time representation. - %%: A literal '%' character. - Added support for `Series.between`. 
+- Added support for `include_groups=False` in `DataFrameGroupBy.apply`. #### Bug Fixes diff --git a/docs/source/modin/supported/groupby_supported.rst b/docs/source/modin/supported/groupby_supported.rst index dde67fbdc1c..89a61c95266 100644 --- a/docs/source/modin/supported/groupby_supported.rst +++ b/docs/source/modin/supported/groupby_supported.rst @@ -39,8 +39,8 @@ Function application +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``apply`` | P | ``axis`` other than 0 is not | ``Y`` if the following are true, otherwise ``N``: | | | | implemented. | - ``func`` is a callable that always returns | -| | | ``include_groups = False`` is | either a pandas DataFrame, a pandas Series, or | -| | | not implemented. | objects that are neither DataFrame nor Series. | +| | | | either a pandas DataFrame, a pandas Series, or | +| | | | objects that are neither DataFrame nor Series. | | | | | - grouping on axis=0 | | | | | - Not applying transform to a dataframe with a | | | | | non-unique index | diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..6b87735de36 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[build-system] +requires = [ + "setuptools", + "protoc-wheel-0==21.1", # Protocol buffer compiler for Snowpark IR + "mypy-protobuf", # used in generating typed Python code from protobuf for Snowpark IR +] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 24f1dbe15ef..1ea499c6001 100644 --- a/setup.py +++ b/setup.py @@ -58,9 +58,8 @@ "graphviz", # used in plot tests "pytest-assume", # sql counter check "decorator", # sql counter check - "protoc-wheel-0==21.1", # Protocol buffer compiler, for Snowpark IR - "mypy-protobuf", # used in generating typed Python code from protobuf for Snowpark IR "lxml", # used in read_xml tests + "tox", # used for setting up testing environments ] # read the version diff --git a/src/conftest.py b/src/conftest.py index 58b7cbd481c..8be82abab71 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -41,6 +41,27 @@ def pytest_runtest_makereport(item, call): return tr +# These tests require python packages that are no longer built for python 3.8 +PYTHON_38_SKIPS = { + "snowpark.session.Session.replicate_local_environment", + "snowpark.session.Session.table_function", +} + +DocTestFinder = doctest.DocTestFinder + + +class CustomDocTestFinder(DocTestFinder): + def _find(self, tests, obj, name, module, source_lines, globs, seen): + if name in PYTHON_38_SKIPS and sys.version_info < (3, 9): + return + return DocTestFinder._find( + self, tests, obj, name, module, source_lines, globs, seen + ) + + +doctest.DocTestFinder = CustomDocTestFinder + + # scope is module so that we ensure we delete the session before # moving onto running the tests in the tests dir. 
Having only one # session is important to certain UDF tests to pass , since they diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index d9209d7abb4..3a9ebf3a832 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -108,7 +108,6 @@ generate_random_alphanumeric, get_copy_into_table_options, is_sql_select_statement, - random_name_for_temp_object, ) from snowflake.snowpark.row import Row from snowflake.snowpark.types import StructType @@ -664,12 +663,8 @@ def large_local_relation_plan( source_plan: Optional[LogicalPlan], schema_query: Optional[str], ) -> SnowflakePlan: - thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled - temp_table_name = ( - f"temp_name_placeholder_{generate_random_alphanumeric()}" - if thread_safe_session_enabled - else random_name_for_temp_object(TempObjectType.TABLE) - ) + temp_table_name = f"temp_name_placeholder_{generate_random_alphanumeric()}" + attributes = [ Attribute(attr.name, attr.datatype, attr.nullable) for attr in output ] @@ -696,9 +691,7 @@ def large_local_relation_plan( Query( create_table_stmt, is_ddl_on_temp_object=True, - temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE) - if thread_safe_session_enabled - else None, + temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE), ), BatchInsertQuery(insert_stmt, data), Query(select_stmt), @@ -1215,7 +1208,6 @@ def read_file( metadata_project: Optional[List[str]] = None, metadata_schema: Optional[List[Attribute]] = None, ): - thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled format_type_options, copy_options = get_copy_into_table_options(options) format_type_options = self._merge_file_format_options( format_type_options, options @@ -1247,8 +1239,6 @@ def read_file( post_queries: List[Query] = [] format_name = self.session.get_fully_qualified_name_if_possible( f"temp_name_placeholder_{generate_random_alphanumeric()}" - if thread_safe_session_enabled - else random_name_for_temp_object(TempObjectType.FILE_FORMAT) ) queries.append( Query( @@ -1262,9 +1252,7 @@ def read_file( is_generated=True, ), is_ddl_on_temp_object=True, - temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT) - if thread_safe_session_enabled - else None, + temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT), ) ) post_queries.append( @@ -1323,8 +1311,6 @@ def read_file( temp_table_name = self.session.get_fully_qualified_name_if_possible( f"temp_name_placeholder_{generate_random_alphanumeric()}" - if thread_safe_session_enabled - else random_name_for_temp_object(TempObjectType.TABLE) ) queries = [ Query( @@ -1337,9 +1323,7 @@ def read_file( is_generated=True, ), is_ddl_on_temp_object=True, - temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE) - if thread_safe_session_enabled - else None, + temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE), ), Query( copy_into_table( diff --git a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py index 78808f8ec60..b7aab6e8c73 100644 --- a/src/snowflake/snowpark/_internal/compiler/plan_compiler.py +++ b/src/snowflake/snowpark/_internal/compiler/plan_compiler.py @@ -204,35 +204,31 @@ def replace_temp_obj_placeholders( To prevent this, we generate queries with temp object name placeholders and replace them with actual temp object here. 
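For illustration (names here are hypothetical): a compiled query may contain a token such as ``temp_name_placeholder_ab12cd3e``; at execution time each such token is swapped for a name freshly generated by ``random_name_for_temp_object`` for the matching ``TempObjectType``, so every execution of the plan works on its own temp objects.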
""" - session = self._plan.session - if session._conn._thread_safe_session_enabled: - # This dictionary will store the mapping between placeholder name and actual temp object name. - placeholders = {} - # Final execution queries - execution_queries = {} - for query_type, query_list in queries.items(): - execution_queries[query_type] = [] - for query in query_list: - # If the query contains a temp object name placeholder, we generate a random - # name for the temp object and add it to the placeholders dictionary. - if query.temp_obj_name_placeholder: - ( - placeholder_name, - temp_obj_type, - ) = query.temp_obj_name_placeholder - placeholders[placeholder_name] = random_name_for_temp_object( - temp_obj_type - ) - - copied_query = copy.copy(query) - for placeholder_name, target_temp_name in placeholders.items(): - # Copy the original query and replace all the placeholder names with the - # actual temp object names. - copied_query.sql = copied_query.sql.replace( - placeholder_name, target_temp_name - ) - - execution_queries[query_type].append(copied_query) - return execution_queries - - return queries + # This dictionary will store the mapping between placeholder name and actual temp object name. + placeholders = {} + # Final execution queries + execution_queries = {} + for query_type, query_list in queries.items(): + execution_queries[query_type] = [] + for query in query_list: + # If the query contains a temp object name placeholder, we generate a random + # name for the temp object and add it to the placeholders dictionary. + if query.temp_obj_name_placeholder: + ( + placeholder_name, + temp_obj_type, + ) = query.temp_obj_name_placeholder + placeholders[placeholder_name] = random_name_for_temp_object( + temp_obj_type + ) + + copied_query = copy.copy(query) + for placeholder_name, target_temp_name in placeholders.items(): + # Copy the original query and replace all the placeholder names with the + # actual temp object names. 
+ copied_query.sql = copied_query.sql.replace( + placeholder_name, target_temp_name + ) + + execution_queries[query_type].append(copied_query) + return execution_queries diff --git a/src/snowflake/snowpark/_internal/server_connection.py b/src/snowflake/snowpark/_internal/server_connection.py index f206b0129b3..9c437e47df2 100644 --- a/src/snowflake/snowpark/_internal/server_connection.py +++ b/src/snowflake/snowpark/_internal/server_connection.py @@ -51,8 +51,6 @@ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import TelemetryClient from snowflake.snowpark._internal.utils import ( - create_rlock, - create_thread_local, escape_quotes, get_application_name, get_version, @@ -173,12 +171,8 @@ def __init__( except TypeError: pass - # thread safe param protection - self._thread_safe_session_enabled = self._get_client_side_session_parameter( - "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False - ) - self._lock = create_rlock(self._thread_safe_session_enabled) - self._thread_store = create_thread_local(self._thread_safe_session_enabled) + self._lock = threading.RLock() + self._thread_store = threading.local() if "password" in self._lower_case_parameters: self._lower_case_parameters["password"] = None diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py index 946e4cc33f6..4fae4b6c60b 100644 --- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py +++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py @@ -2,12 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # import logging +from threading import RLock import weakref from collections import defaultdict from typing import TYPE_CHECKING, Dict from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable -from snowflake.snowpark._internal.utils import create_rlock, is_in_stored_procedure +from snowflake.snowpark._internal.utils import is_in_stored_procedure _logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ def __init__(self, session: "Session") -> None: # this dict will still be maintained even if the cleaner is stopped (`stop()` is called) self.ref_count_map: Dict[str, int] = defaultdict(int) # Lock to protect the ref_count_map - self.lock = create_rlock(session._conn._thread_safe_session_enabled) + self.lock = RLock() def add(self, table: SnowflakeTable) -> None: with self.lock: diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 74cf13cc095..7a38368e8a8 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -378,12 +378,7 @@ def normalize_path(path: str, is_local: bool) -> str: return f"'{path}'" -def warn_session_config_update_in_multithreaded_mode( - config: str, thread_safe_mode_enabled: bool -) -> None: - if not thread_safe_mode_enabled: - return - +def warn_session_config_update_in_multithreaded_mode(config: str) -> None: if threading.active_count() > 1: _logger.warning( "You might have more than one threads sharing the Session object trying to update " @@ -798,47 +793,6 @@ def warning(self, text: str) -> None: self.count += 1 -# TODO: SNOW-1720855: Remove DummyRLock and DummyThreadLocal after the rollout -class DummyRLock: - """This is a dummy lock that is used in place of threading.Rlock when multithreading is - disabled.""" - - def __enter__(self): - pass - - def __exit__(self, exc_type, 
exc_val, exc_tb): - pass - - def acquire(self, *args, **kwargs): - pass # pragma: no cover - - def release(self, *args, **kwargs): - pass # pragma: no cover - - -class DummyThreadLocal: - """This is a dummy thread local class that is used in place of threading.local when - multithreading is disabled.""" - - pass - - -def create_thread_local( - thread_safe_session_enabled: bool, -) -> Union[threading.local, DummyThreadLocal]: - if thread_safe_session_enabled: - return threading.local() - return DummyThreadLocal() - - -def create_rlock( - thread_safe_session_enabled: bool, -) -> Union[threading.RLock, DummyRLock]: - if thread_safe_session_enabled: - return threading.RLock() - return DummyRLock() - - warning_dict: Dict[str, WarningHelper] = {} diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 21c36fb2200..887e764344e 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -3688,6 +3688,8 @@ def with_column( self, col_name: str, col: Union[Column, TableFunctionCall], + *, + keep_column_order: bool = False, ast_stmt: proto.Expr = None, _emit_ast: bool = True, ) -> "DataFrame": @@ -3730,6 +3732,7 @@ def with_column( Args: col_name: The name of the column to add or replace. col: The :class:`Column` or :class:`table_function.TableFunctionCall` with single column output to add or replace. + keep_column_order: If ``True``, the original order of the columns in the DataFrame is preserved when reaplacing a column. """ if ast_stmt is None and _emit_ast: ast_stmt = self._session._ast_batch.assign() @@ -3738,7 +3741,13 @@ def with_column( build_expr_from_snowpark_column_or_table_fn(expr.col, col) self._set_ast_ref(expr.df) - df = self.with_columns([col_name], [col], _ast_stmt=ast_stmt, _emit_ast=False) + df = self.with_columns( + [col_name], + [col], + keep_column_order=keep_column_order, + _ast_stmt=ast_stmt, + _emit_ast=False, + ) if _emit_ast: df._ast_id = ast_stmt.var_id.bitfield1 @@ -3751,6 +3760,8 @@ def with_columns( self, col_names: List[str], values: List[Union[Column, TableFunctionCall]], + *, + keep_column_order: bool = False, _ast_stmt: proto.Expr = None, _emit_ast: bool = True, ) -> "DataFrame": @@ -3797,6 +3808,7 @@ def with_columns( col_names: A list of the names of the columns to add or replace. values: A list of the :class:`Column` objects or :class:`table_function.TableFunctionCall` object to add or replace. + keep_column_order: If ``True``, the original order of the columns in the DataFrame is preserved when reaplacing a column. """ # Get a list of the new columns and their dedupped values qualified_names = [quote_name(n) for n in col_names] @@ -3837,14 +3849,7 @@ def with_columns( names = col_names[i : i + offset + 1] new_cols.append(col.as_(*names)) - # Get a list of existing column names that are not being replaced - old_cols = [ - Column(field) - for field in self._output - if field.name not in new_column_names - ] - - # AST. + # AST if _ast_stmt is None and _emit_ast: _ast_stmt = self._session._ast_batch.assign() expr = with_src_position( @@ -3856,8 +3861,41 @@ def with_columns( build_expr_from_snowpark_column_or_table_fn(expr.values.add(), value) self._set_ast_ref(expr.df) - # Put it all together - df = self.select([*old_cols, *new_cols], _ast_stmt=_ast_stmt, _emit_ast=False) + # If there's a table function call or keep_column_order=False, + # we do the original "remove old columns and append new ones" logic. 
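+ # Illustration with hypothetical columns A, B, C when replacing column B:
+ #   keep_column_order=False (default): resulting order is A, C, B_new
+ #   keep_column_order=True:            resulting order is A, B_new, C
+ # Columns that do not replace an existing column are appended at the end in either case.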
+ if num_table_func_calls > 0 or not keep_column_order: + old_cols = [ + Column(field) + for field in self._output + if field.name not in new_column_names + ] + final_cols = [*old_cols, *new_cols] + else: + # keep_column_order=True and no table function calls + # Re-insert replaced columns in their original positions if they exist + replaced_map = { + name: new_col for name, new_col in zip(qualified_names, new_cols) + } + final_cols = [] + used = set() # track which new cols we've inserted + + for field in self._output: + field_quoted = quote_name(field.name) + # If this old column name is being replaced, insert the new col at the same position + if field_quoted in replaced_map: + final_cols.append(replaced_map[field_quoted]) + used.add(field_quoted) + else: + # keep the original col + final_cols.append(Column(field)) + + # For any new columns that didn't exist in the old schema, append them at the end + for name, c in replaced_map.items(): + if name not in used: + final_cols.append(c) + + # Construct the final DataFrame + df = self.select(final_cols, _ast_stmt=_ast_stmt, _emit_ast=False) if _emit_ast: df._ast_id = _ast_stmt.var_id.bitfield1 diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index fbb0055e203..2e24105f445 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -224,6 +224,7 @@ StoredProcedureRegistration, ) from snowflake.snowpark.types import ( + ArrayType, DataType, FloatType, PandasDataFrameType, @@ -3561,20 +3562,67 @@ def _concat_ws_ignore_nulls(sep: str, *cols: ColumnOrName) -> Column: |Hello | ----------------------------------------------------- + + >>> df = session.create_dataframe([ + ... (['Hello', 'World', None], None, '!'), + ... (['Hi', 'World', "."], "I'm Dad", '.'), + ... ], schema=['a', 'b', 'c']) + >>> df.select(_concat_ws_ignore_nulls(", ", "a", "b", "c")).show() + ----------------------------------------------------- + |"CONCAT_WS_IGNORE_NULLS(', ', ""A"",""B"",""C"")" | + ----------------------------------------------------- + |Hello, World, ! | + |Hi, World, ., I'm Dad, . | + ----------------------------------------------------- + """ # TODO: SNOW-1831917 create ast columns = [_to_col_if_str(c, "_concat_ws_ignore_nulls") for c in cols] names = ",".join([c.get_name() for c in columns]) - input_column_array = array_construct_compact(*columns, _emit_ast=False) - reduced_result = builtin("reduce", _emit_ast=False)( - input_column_array, - lit("", _emit_ast=False), - sql_expr(f"(l, r) -> l || '{sep}' || r"), - ) - return substring(reduced_result, len(sep) + 1, _emit_ast=False).alias( - f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False - ) + # The implementation of this function is as follows with example input of + # sep = "," and row = [a, NULL], b, NULL, c: + # 1. Cast all columns to array. + # [a, NULL], [b], NULL, [c] + # 2. Combine all arrays into a array of arrays after removing nulls (array_construct_compact). + # [[a, NULL], [b], [c]] + # 3. Flatten the array of arrays into a single array (array_flatten). + # [a, NULL, b, c] + # 4. Filter out nulls (array_remove_nulls). + # [a, b, c] + # 5. Concatenate the non-null values into a single string (concat_strings_with_sep). 
+ # "a,b,c" + + def array_remove_nulls(col: Column) -> Column: + """Expects an array and returns an array with nulls removed.""" + return builtin("filter", _emit_ast=False)( + col, sql_expr("x -> NOT IS_NULL_VALUE(x)", _emit_ast=False) + ) + + def concat_strings_with_sep(col: Column) -> Column: + """ + Expects an array of strings and returns a single string + with the values concatenated with the separator. + """ + return substring( + builtin("reduce", _emit_ast=False)( + col, lit(""), sql_expr(f"(l, r) -> l || '{sep}' || r", _emit_ast=False) + ), + len(sep) + 1, + _emit_ast=False, + ) + + return concat_strings_with_sep( + array_remove_nulls( + array_flatten( + array_construct_compact( + *[c.cast(ArrayType(), _emit_ast=False) for c in columns], + _emit_ast=False, + ), + _emit_ast=False, + ) + ) + ).alias(f"CONCAT_WS_IGNORE_NULLS('{sep}', {names})", _emit_ast=False) @publicapi @@ -3828,6 +3876,19 @@ def date_format( |2022/05/15 10:45:00 | ----------------------- + + Example:: + >>> df = session.sql("select '2023-10-10'::DATE as date_col, '2023-10-10 15:30:00'::TIMESTAMP as timestamp_col") + >>> df.select( + ... date_format('date_col', 'YYYY/MM/DD').as_('formatted_dt'), + ... date_format('timestamp_col', 'YYYY/MM/DD HH:mi:ss').as_('formatted_ts') + ... ).show() + ---------------------------------------- + |"FORMATTED_DT" |"FORMATTED_TS" | + ---------------------------------------- + |2023/10/10 |2023/10/10 15:30:00 | + ---------------------------------------- + """ # AST. @@ -3836,7 +3897,11 @@ def date_format( ast = proto.Expr() build_builtin_fn_apply(ast, "date_format", c, fmt) - ans = to_char(try_cast(c, TimestampType(), _emit_ast=False), fmt, _emit_ast=False) + ans = to_char( + try_cast(to_char(c, _emit_ast=False), TimestampType(), _emit_ast=False), + fmt, + _emit_ast=False, + ) ans._ast = ast return ans diff --git a/src/snowflake/snowpark/mock/_connection.py b/src/snowflake/snowpark/mock/_connection.py index c4ef34bc27b..bc29bb905dd 100644 --- a/src/snowflake/snowpark/mock/_connection.py +++ b/src/snowflake/snowpark/mock/_connection.py @@ -6,6 +6,7 @@ import functools import json import logging +import threading import uuid from copy import copy from decimal import Decimal @@ -30,7 +31,6 @@ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.server_connection import DEFAULT_STRING_SIZE from snowflake.snowpark._internal.utils import ( - create_rlock, is_in_stored_procedure, result_set_to_rows, ) @@ -297,11 +297,7 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None: self._cursor = Mock() self._options = options or {} session_params = self._options.get("session_parameters", {}) - # thread safe param protection - self._thread_safe_session_enabled = session_params.get( - "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False - ) - self._lock = create_rlock(self._thread_safe_session_enabled) + self._lock = threading.RLock() self._lower_case_parameters = {} self._query_listeners = set() self._telemetry_client = Mock() diff --git a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py index 0bd2fa9991c..42974b7c21b 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py @@ -842,15 +842,13 @@ def convert_numpy_int_result_to_int(value: Any) -> Any: DUMMY_BOOL_INPUT = native_pd.Series([False, True]) -DUMMY_INT_INPUT = native_pd.Series( - [-37, -9, -2, -1, 0, 2, 
3, 5, 7, 9, 13, 16, 20] - + np.power(10, np.arange(19)).tolist() - + np.multiply(-1, np.power(10, np.arange(19))).tolist() -) +# Note: we use only small dummy values here to avoid the risk of certain callables +# taking a long time to execute (where execution time is a function of the input value). +# As a downside this reduces diversity in input data so will reduce the effectiveness +# type inference framework in some rare cases. +DUMMY_INT_INPUT = native_pd.Series([-37, -9, -2, -1, 0, 2, 3, 5, 7, 9, 13, 16, 20, 101]) DUMMY_FLOAT_INPUT = native_pd.Series( - [-9.9, -2.2, -1.0, 0.0, 0.5, 0.33, None, 0.99, 2.0, 3.0, 5.0, 7.7, 9.898989] - + np.power(10.1, np.arange(19)).tolist() - + np.multiply(-1.0, np.power(10.1, np.arange(19))).tolist() + [-9.9, -2.2, -1.0, 0.0, 0.5, 0.33, None, 0.99, 2.0, 3.0, 5.0, 7.7, 9.898989, 100.1] ) DUMMY_STRING_INPUT = native_pd.Series( ["", "a", "A", "0", "1", "01", "123", "-1", "-12", "true", "True", "false", "False"] diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 5f7f38d99f7..fb7439059ed 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -3979,6 +3979,7 @@ def groupby_apply( agg_args: Any, agg_kwargs: dict[str, Any], series_groupby: bool, + include_groups: bool, force_single_group: bool = False, force_list_like_to_series: bool = False, ) -> "SnowflakeQueryCompiler": @@ -4001,6 +4002,9 @@ def groupby_apply( Keyword arguments to pass to agg_func when applying it to each group. series_groupby: Whether we are performing a SeriesGroupBy.apply() instead of a DataFrameGroupBy.apply() + include_groups: + When True, will include grouping keys when calling func in the case that + they are columns of the DataFrame. force_single_group: Force single group (empty set of group by labels) useful for DataFrame.apply() with axis=0 force_list_like_to_series: @@ -4019,14 +4023,6 @@ def groupby_apply( + f"level={level}, and axis={axis}" ) - if "include_groups" in agg_kwargs: - # exclude "include_groups" from the apply function kwargs - include_groups = agg_kwargs.pop("include_groups") - if not include_groups: - ErrorMessage.not_implemented( - f"No support for groupby.apply with include_groups = {include_groups}" - ) - sort = groupby_kwargs.get("sort", True) as_index = groupby_kwargs.get("as_index", True) dropna = groupby_kwargs.get("dropna", True) @@ -4051,17 +4047,36 @@ def groupby_apply( ) snowflake_type_map = self._modin_frame.quoted_identifier_to_snowflake_type() - - # For DataFrameGroupBy, `func` operates on this frame in its entirety. - # For SeriesGroupBy, this frame may also include some grouping columns - # that `func` should not take as input. In that case, the only column - # that `func` takes as input is the last data column, so grab just that - # column with a slice starting at index -1 and ending at None. - input_data_column_identifiers = ( - self._modin_frame.data_column_snowflake_quoted_identifiers[ - slice(-1, None) if series_groupby else slice(None) - ] - ) + input_data_column_positions = [ + i + for i, identifier in enumerate( + self._modin_frame.data_column_snowflake_quoted_identifiers + ) + if ( + ( + # For SeriesGroupBy, this frame may also include some + # grouping columns that `func` should not take as input. In + # that case, the only column that `func` takes as input is + # the last data column, so take just that column. 
+ # include_groups has no effect. + i + == len(self._modin_frame.data_column_snowflake_quoted_identifiers) + - 1 + ) + if series_groupby + else ( + # For DataFrameGroupBy, if include_groups, we apply the + # function to all data columns. Otherwise, we exclude + # data columns that we are grouping by. + include_groups + or identifier not in by_snowflake_quoted_identifiers_list + ) + ) + ] + input_data_column_identifiers = [ + self._modin_frame.data_column_snowflake_quoted_identifiers[i] + for i in input_data_column_positions + ] # TODO(SNOW-1210489): When type hints show that `agg_func` returns a # scalar, we can use a vUDF instead of a vUDTF and we can skip the @@ -4070,7 +4085,9 @@ def groupby_apply( agg_func, agg_args, agg_kwargs, - data_column_index=self._modin_frame.data_columns_index, + data_column_index=self._modin_frame.data_columns_index[ + input_data_column_positions + ], index_column_names=self._modin_frame.index_column_pandas_labels, input_data_column_types=[ snowflake_type_map[quoted_identifier] @@ -8511,6 +8528,7 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no series_groupby=True, force_single_group=True, force_list_like_to_series=True, + include_groups=True, ) data_col_result_frame = data_col_qc._modin_frame diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py index 3270b0832be..d05780d53df 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py @@ -1078,6 +1078,9 @@ def apply(): A callable that takes a dataframe or series as its first argument, and returns a dataframe, a series or a scalar. In addition the callable may take positional and keyword arguments. + include_groups : bool, default True + When True, will apply ``func`` to the groups in the case that they + are columns of the DataFrame. args, kwargs : tuple and dict Optional positional and keyword arguments to pass to ``func``. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py index add8e432df1..cc2df69a6e3 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/groupby_overrides.py @@ -238,7 +238,7 @@ def get_group(self, name, obj=None): # Function application ########################################################################### - def apply(self, func, *args, **kwargs): + def apply(self, func, *args, include_groups=True, **kwargs): # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions # TODO: SNOW-1244717: Explore whether window function are performant and can be used # whenever `func` is an aggregation function. 
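+ # Illustrative semantics (hypothetical frame, mirroring pandas): given
+ #   df = pd.DataFrame({"key": [1, 1, 2], "val": [10, 20, 30]})
+ #   df.groupby("key").apply(lambda g: g.sum(), include_groups=False)
+ # the callable receives only the non-grouping data columns (here just "val");
+ # the default include_groups=True also passes the grouping column "key" to the callable.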
@@ -253,6 +253,7 @@ def apply(self, func, *args, **kwargs): agg_args=args, agg_kwargs=kwargs, series_groupby=False, + include_groups=include_groups, ) ) if dataframe_result.columns.equals(pandas.Index([MODIN_UNNAMED_SERIES_LABEL])): @@ -1445,7 +1446,7 @@ def get_group(self, name, obj=None): # Function application ########################################################################### - def apply(self, func, *args, **kwargs): + def apply(self, func, *args, include_groups=True, **kwargs): # TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.SeriesGroupBy functions if not callable(func): raise NotImplementedError("No support for non-callable `func`") @@ -1457,6 +1458,7 @@ def apply(self, func, *args, **kwargs): groupby_kwargs=self._kwargs, agg_args=args, agg_kwargs=kwargs, + include_groups=include_groups, # TODO(https://github.com/modin-project/modin/issues/7096): # upstream the series_groupby param to Modin series_groupby=True, diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a23d5c19324..05e2dcebe45 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -12,6 +12,7 @@ import re import sys import tempfile +import threading import warnings from array import array from functools import reduce @@ -104,8 +105,6 @@ TempObjectType, calculate_checksum, check_flatten_mode, - create_rlock, - create_thread_local, deprecated, escape_quotes, experimental, @@ -258,9 +257,6 @@ _PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND = ( "PYTHON_SNOWPARK_LARGE_QUERY_BREAKDOWN_COMPLEXITY_LOWER_BOUND" ) -_PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION = ( - "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION" -) # Flag for controlling the usage of scoped temp read only table. _PYTHON_SNOWPARK_ENABLE_SCOPED_TEMP_READ_ONLY_TABLE = ( "PYTHON_SNOWPARK_ENABLE_SCOPED_TEMP_READ_ONLY_TABLE" @@ -350,8 +346,6 @@ class Session: :class:`Session` contains functions to construct a :class:`DataFrame` like :meth:`table`, :meth:`sql` and :attr:`read`, etc. - - A :class:`Session` object is not thread-safe. """ class RuntimeConfig: @@ -639,19 +633,18 @@ def __init__( DEFAULT_COMPLEXITY_SCORE_UPPER_BOUND, ), ) - self._thread_store = create_thread_local( - self._conn._thread_safe_session_enabled - ) - self._lock = create_rlock(self._conn._thread_safe_session_enabled) + + self._thread_store = threading.local() + self._lock = RLock() # this lock is used to protect _packages. We use introduce a new lock because add_packages # launches a query to snowflake to get all version of packages available in snowflake. This # query can be slow and prevent other threads from moving on waiting for _lock. 
- self._package_lock = create_rlock(self._conn._thread_safe_session_enabled) + self._package_lock = RLock() # this lock is used to protect race-conditions when evaluating critical lazy properties # of SnowflakePlan or Selectable objects - self._plan_lock = create_rlock(self._conn._thread_safe_session_enabled) + self._plan_lock = RLock() self._custom_package_usage_config: Dict = {} self._conf = self.RuntimeConfig(self, options or {}) @@ -902,9 +895,7 @@ def custom_package_usage_config(self) -> Dict: @sql_simplifier_enabled.setter def sql_simplifier_enabled(self, value: bool) -> None: - warn_session_config_update_in_multithreaded_mode( - "sql_simplifier_enabled", self._conn._thread_safe_session_enabled - ) + warn_session_config_update_in_multithreaded_mode("sql_simplifier_enabled") with self._lock: self._conn._telemetry_client.send_sql_simplifier_telemetry( @@ -921,9 +912,7 @@ def sql_simplifier_enabled(self, value: bool) -> None: @cte_optimization_enabled.setter @experimental_parameter(version="1.15.0") def cte_optimization_enabled(self, value: bool) -> None: - warn_session_config_update_in_multithreaded_mode( - "cte_optimization_enabled", self._conn._thread_safe_session_enabled - ) + warn_session_config_update_in_multithreaded_mode("cte_optimization_enabled") with self._lock: if value: @@ -937,8 +926,7 @@ def cte_optimization_enabled(self, value: bool) -> None: def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: """Set the value for eliminate_numeric_sql_value_cast_enabled""" warn_session_config_update_in_multithreaded_mode( - "eliminate_numeric_sql_value_cast_enabled", - self._conn._thread_safe_session_enabled, + "eliminate_numeric_sql_value_cast_enabled" ) if value in [True, False]: @@ -957,7 +945,7 @@ def eliminate_numeric_sql_value_cast_enabled(self, value: bool) -> None: def auto_clean_up_temp_table_enabled(self, value: bool) -> None: """Set the value for auto_clean_up_temp_table_enabled""" warn_session_config_update_in_multithreaded_mode( - "auto_clean_up_temp_table_enabled", self._conn._thread_safe_session_enabled + "auto_clean_up_temp_table_enabled" ) if value in [True, False]: @@ -980,7 +968,7 @@ def large_query_breakdown_enabled(self, value: bool) -> None: overall performance. 
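Illustrative usage (assuming this is exposed as a property setter on ``Session``): ``session.large_query_breakdown_enabled = True``.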
""" warn_session_config_update_in_multithreaded_mode( - "large_query_breakdown_enabled", self._conn._thread_safe_session_enabled + "large_query_breakdown_enabled" ) if value in [True, False]: @@ -998,8 +986,7 @@ def large_query_breakdown_enabled(self, value: bool) -> None: def large_query_breakdown_complexity_bounds(self, value: Tuple[int, int]) -> None: """Set the lower and upper bounds for the complexity score used in large query breakdown optimization.""" warn_session_config_update_in_multithreaded_mode( - "large_query_breakdown_complexity_bounds", - self._conn._thread_safe_session_enabled, + "large_query_breakdown_complexity_bounds" ) if len(value) != 2: diff --git a/tests/conftest.py b/tests/conftest.py index b5fe956a014..576dd8e23b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -56,9 +56,6 @@ def is_excluded_frontend_file(path): def pytest_addoption(parser, pluginmanager): parser.addoption("--disable_sql_simplifier", action="store_true", default=False) parser.addoption("--disable_cte_optimization", action="store_true", default=False) - parser.addoption( - "--disable_multithreading_mode", action="store_true", default=False - ) parser.addoption("--skip_sql_count_check", action="store_true", default=False) if not any( "--local_testing_mode" in opt.names() for opt in parser._anonymous.options @@ -151,17 +148,6 @@ def proto_generated(): subprocess.check_call([sys.executable, "-m", "tox", "-e", "protoc"]) -MULTITHREADING_TEST_MODE_ENABLED = False - - -@pytest.fixture(scope="session", autouse=True) -def multithreading_mode_enabled(pytestconfig): - enabled = not pytestconfig.getoption("disable_multithreading_mode") - global MULTITHREADING_TEST_MODE_ENABLED - MULTITHREADING_TEST_MODE_ENABLED = enabled - return enabled - - @pytest.fixture(scope="session") def unparser_jar(pytestconfig): unparser_jar = pytestconfig.getoption("--unparser_jar") diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index e173eb52b8e..045e37af102 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -219,7 +219,6 @@ def session( sql_simplifier_enabled, local_testing_mode, cte_optimization_enabled, - multithreading_mode_enabled, ast_enabled, validate_ast, unparser_jar, @@ -234,10 +233,6 @@ def session( session = ( Session.builder.configs(db_parameters) .config("local_testing", local_testing_mode) - .config( - "session_parameters", - {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": multithreading_mode_enabled}, - ) .create() ) session.sql_simplifier_enabled = sql_simplifier_enabled diff --git a/tests/integ/modin/groupby/test_groupby_apply.py b/tests/integ/modin/groupby/test_groupby_apply.py index c6c805a0ca3..c181cc46872 100644 --- a/tests/integ/modin/groupby/test_groupby_apply.py +++ b/tests/integ/modin/groupby/test_groupby_apply.py @@ -74,7 +74,7 @@ def transform_that_changes_columns(df: native_pd.DataFrame) -> native_pd.DataFra return native_pd.DataFrame( { "custom_sum": df["int_col"].cumsum() + df["int_col"].max(), - "custom_string": df["string_col_1"].astype("object").cumsum() + "custom_string": (df["string_col_2"].astype("object").cumsum()) + df["string_col_2"].str.cat(sep="-"), } ) @@ -165,25 +165,31 @@ def grouping_dfs_with_multiindexes() -> tuple[pd.DataFrame, native_pd.DataFrame] UDTF_COUNT = 1 +@pytest.mark.parametrize( + "include_groups", [True, False], ids=lambda v: f"include_groups_{v}" +) class TestFuncReturnsDataFrame: @pytest.mark.parametrize( "func", [ normalize_numeric_columns_by_sum, param( - lambda df: df.iloc[:, [2, 2, 0, 1]], + lambda df: df.iloc[:, [1, 0]], 
id="different_columns_but_same_index", ), param( lambda df: ( native_pd.DataFrame( - [["k0_grouped", 0, 1], ["k0_grouped", 2, 2]], + [ + list(range(0, len(df.columns))), + list(range(1, 1 + len(df.columns))), + ], index=native_pd.Index([None, 3], name="new_index"), columns=df.columns, ) - if df[("a", "string_col_1")].iloc[0] == "k0" + if df.index[0][0] == "i1" else native_pd.DataFrame( - [["other_key_grouped", 100, 101]], + [list(range(2, 2 + len(df.columns)))], index=native_pd.Index([None, 3], name="new_index"), columns=df.columns, ) @@ -199,13 +205,13 @@ class TestFuncReturnsDataFrame: join_count=JOIN_COUNT, ) def test_group_by_one_column_and_one_level_with_default_kwargs( - self, grouping_dfs_with_multiindexes, func + self, grouping_dfs_with_multiindexes, func, include_groups ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby( ["level_0", ("a", "string_col_1")], - ).apply(func), + ).apply(func, include_groups=include_groups), ) @sql_count_checker( @@ -213,12 +219,14 @@ def test_group_by_one_column_and_one_level_with_default_kwargs( udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_df_with_default_index(self, grouping_dfs_with_multiindexes): + def test_df_with_default_index( + self, grouping_dfs_with_multiindexes, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.reset_index(drop=True) .groupby(("a", "string_col_1")) - .apply(normalize_numeric_columns_by_sum), + .apply(normalize_numeric_columns_by_sum, include_groups=include_groups), ) @sql_count_checker( @@ -226,11 +234,12 @@ def test_df_with_default_index(self, grouping_dfs_with_multiindexes): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_func_returns_empty_frame(self): + def test_func_returns_empty_frame(self, include_groups): eval_snowpark_pandas_result( *create_test_dfs([[1, 2], [3, 4]]), lambda df: df.groupby(0).apply( - lambda df: native_pd.DataFrame(index=[1, 3], columns=[4, 5]) + lambda df: native_pd.DataFrame(index=[1, 3], columns=[4, 5]), + include_groups=include_groups, ), ) @@ -239,13 +248,15 @@ def test_func_returns_empty_frame(self): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_args_and_kwargs(self, grouping_dfs_with_multiindexes): + def test_args_and_kwargs(self, grouping_dfs_with_multiindexes, include_groups): def func(df, num1, str1): return df.applymap(lambda x: "_".join((str(x), num1, str1))) eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, - lambda df: df.groupby("level_0").apply(func, "0.3", str1="str1"), + lambda df: df.groupby("level_0").apply( + func, "0.3", str1="str1", include_groups=include_groups + ), ) @pytest.mark.parametrize( @@ -265,13 +276,17 @@ def func(df, num1, str1): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_group_by_level(self, grouping_dfs_with_multiindexes, level): + def test_group_by_level( + self, grouping_dfs_with_multiindexes, level, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, - lambda df: df.groupby(level=level).apply(lambda df: df.iloc[::-1, ::-1]), + lambda df: df.groupby(level=level).apply( + lambda df: df.iloc[::-1, ::-1], include_groups=include_groups + ), ) - def test_dropna_false(self, grouping_dfs_with_multiindexes): + def test_dropna_false(self, grouping_dfs_with_multiindexes, include_groups): snow_df, pandas_df = grouping_dfs_with_multiindexes # check that we are going to group by a column that has nulls. 
assert pandas_df[("a", "string_col_1")].isna().sum() > 0 @@ -280,7 +295,7 @@ def operation(df: native_pd.DataFrame) -> native_pd.DataFrame: return df.groupby( ("a", "string_col_1"), dropna=False, - ).apply(normalize_numeric_columns_by_sum) + ).apply(normalize_numeric_columns_by_sum, include_groups=include_groups) with SqlCounter( # When dropna=False, we can skip the dropna query @@ -328,10 +343,14 @@ def operation(df: native_pd.DataFrame) -> native_pd.DataFrame: np.nan, ], ) - def test_group_dataframe_with_column_of_all_nulls_snow_1233832(self, null_value): + def test_group_dataframe_with_column_of_all_nulls_snow_1233832( + self, null_value, include_groups + ): eval_snowpark_pandas_result( *create_test_dfs({"null_col": [null_value], "int_col": [1]}), - lambda df: df.groupby("int_col").apply(lambda x: x), + lambda df: df.groupby("int_col").apply( + lambda x: x, include_groups=include_groups + ), ) @sql_count_checker( @@ -346,11 +365,11 @@ def test_group_dataframe_with_column_of_all_nulls_snow_1233832(self, null_value) ("a", "string_col_1"), ], ) - def test_sort_false(self, grouping_dfs_with_multiindexes, by): + def test_sort_false(self, grouping_dfs_with_multiindexes, by, include_groups): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby(by, sort=False).apply( - normalize_numeric_columns_by_sum + normalize_numeric_columns_by_sum, include_groups=include_groups ), ) @@ -367,10 +386,14 @@ def test_sort_false(self, grouping_dfs_with_multiindexes, by): # behavior depends on whether the function is a transform. [normalize_numeric_columns_by_sum, duplicate_df_rowwise], ) - def test_as_index_false(self, grouping_dfs_with_multiindexes, by, func): + def test_as_index_false( + self, grouping_dfs_with_multiindexes, by, func, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, - lambda df: df.groupby(by=by, as_index=False).apply(func), + lambda df: df.groupby(by=by, as_index=False).apply( + func, include_groups=include_groups + ), ) @pytest.mark.parametrize( @@ -386,23 +409,25 @@ def test_as_index_false(self, grouping_dfs_with_multiindexes, by, func): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_group_keys_false(self, grouping_dfs_with_multiindexes, as_index): + def test_group_keys_false( + self, grouping_dfs_with_multiindexes, as_index, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby( by=["level_0", ("a", "string_col_1")], as_index=as_index, group_keys=False, - ).apply(normalize_numeric_columns_by_sum), + ).apply(normalize_numeric_columns_by_sum, include_groups=include_groups), ) @sql_count_checker(query_count=0) @pytest.mark.xfail(strict=True, raises=NotImplementedError) - def test_axis_one(self, grouping_dfs_with_multiindexes): + def test_axis_one(self, grouping_dfs_with_multiindexes, include_groups): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby(level=0, axis=1).apply( - normalize_numeric_columns_by_sum + normalize_numeric_columns_by_sum, include_groups=include_groups ), ) @@ -457,7 +482,7 @@ def test_axis_one(self, grouping_dfs_with_multiindexes): ], ) def test_df_with_single_level_labels( - self, by, as_index, func, group_keys, dfs_kwargs + self, by, as_index, func, group_keys, dfs_kwargs, include_groups ): mdf, pdf = create_test_dfs(**dfs_kwargs) @@ -466,7 +491,7 @@ def operation(df: native_pd.DataFrame) -> native_pd.DataFrame: by=by, group_keys=group_keys, as_index=as_index, - ).apply(func) + ).apply(func, 
include_groups=include_groups) pandas_result = operation(pdf) with SqlCounter( @@ -503,7 +528,7 @@ def operation(df: native_pd.DataFrame) -> native_pd.DataFrame: native_pd.DataFrame( { "custom_sum": [26, 30, 30, 46], - "custom_string": ["k0e", "k1d-b", "k0c", "k1k0d-b"], + "custom_string": ["ee", "dd-b", "cc", "dbd-b"], }, index=native_pd.Index( ["i0", "i1", "i2", "i1"], name="index" @@ -518,10 +543,10 @@ def operation(df: native_pd.DataFrame) -> native_pd.DataFrame: { "custom_sum": [29, 28, 44, 60], "custom_string": [ - "k0e-c-b", - "k1d", - "k0k0e-c-b", - "k0k0k0e-c-b", + "ee-c-b", + "dd", + "ece-c-b", + "ecbe-c-b", ], }, index=native_pd.Index( @@ -542,14 +567,14 @@ def operation(df: native_pd.DataFrame) -> native_pd.DataFrame: udtf_count=UDTF_COUNT, ) def test_apply_transfform_to_subset( - self, grouping_dfs_with_multiindexes, set_sql_simplifier + self, grouping_dfs_with_multiindexes, set_sql_simplifier, include_groups ): """Test a bug where groupby.apply on a subset of columns was giving a syntax error only if sql simplifier was off.""" eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby("level_0", group_keys=False)[ [("b", "int_col"), ("b", "string_col_2")] - ].apply(normalize_numeric_columns_by_sum), + ].apply(normalize_numeric_columns_by_sum, include_groups=include_groups), ) @pytest.mark.parametrize( @@ -574,10 +599,14 @@ def test_apply_transfform_to_subset( join_count=JOIN_COUNT, udtf_count=UDTF_COUNT, ) - def test_numpy_ints_in_result(self, grouping_dfs_with_multiindexes, result): + def test_numpy_ints_in_result( + self, grouping_dfs_with_multiindexes, result, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, - lambda df: df.groupby(level=0).apply(lambda grp: result), + lambda df: df.groupby(level=0).apply( + lambda grp: result, include_groups=include_groups + ), ) @pytest.mark.xfail( @@ -585,13 +614,16 @@ def test_numpy_ints_in_result(self, grouping_dfs_with_multiindexes, result): raises=NotImplementedError, reason="No support for applying a function that returns two dataframes that have different labels for the column at a given position", ) - def test_mismatched_data_column_positions(self, grouping_dfs_with_multiindexes): + def test_mismatched_data_column_positions( + self, grouping_dfs_with_multiindexes, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby("level_0").apply( lambda df: native_pd.DataFrame([0], columns=["a"]) if df.iloc[0, 1] == 13 - else native_pd.DataFrame([1], columns=["b"]) + else native_pd.DataFrame([1], columns=["b"]), + include_groups=include_groups, ), ) @@ -600,7 +632,9 @@ def test_mismatched_data_column_positions(self, grouping_dfs_with_multiindexes): raises=NotImplementedError, reason="No support for applying a function that returns two dataframes that have different names for a given index level", ) - def test_mismatched_index_column_positions(self, grouping_dfs_with_multiindexes): + def test_mismatched_index_column_positions( + self, grouping_dfs_with_multiindexes, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby("level_0").apply( @@ -612,11 +646,12 @@ def test_mismatched_index_column_positions(self, grouping_dfs_with_multiindexes) else native_pd.DataFrame( [0], index=native_pd.Index([0], name="b"), - ) + ), + include_groups=include_groups, ), ) - def test_duplicate_index_groupby_mismatch_with_pandas(self): + def test_duplicate_index_groupby_mismatch_with_pandas(self, 
include_groups): # use a frame that has duplicates in its index to reproduce https://github.com/pandas-dev/pandas/issues/57906 # this bug is fixed in snowpark pandas but not in pandas. snow_df, pandas_df = create_test_dfs( @@ -636,7 +671,7 @@ def test_duplicate_index_groupby_mismatch_with_pandas(self): def groupby_apply_without_sort(df): return df.groupby("index", sort=False, dropna=False, group_keys=False)[ "int_col" - ].apply(lambda v: v) + ].apply(lambda v: v, include_groups=include_groups) # Assertion fails because index order is different due to pandas issue # 57906. @@ -667,6 +702,9 @@ def groupby_apply_without_sort(df): ) +@pytest.mark.parametrize( + "include_groups", [True, False], ids=lambda v: f"include_groups_{v}" +) class TestFuncReturnsScalar: @pytest.mark.parametrize("sort", [True, False], ids=lambda v: f"sort_{v}") @pytest.mark.parametrize("as_index", [True, False], ids=lambda v: f"as_index_{v}") @@ -679,7 +717,9 @@ class TestFuncReturnsScalar: udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_volume_from_brazil_per_year(self, sort, dropna, group_keys, as_index): + def test_volume_from_brazil_per_year( + self, sort, dropna, group_keys, as_index, include_groups + ): """Test an example that a user provided here: https://snowflake.slack.com/archives/C05RX90ETGU/p1707126781811689""" # TODO: group_keys should have no impact when func: df -> scalar # (normally it tells whether to include group keys in the index) @@ -708,7 +748,9 @@ def test_volume_from_brazil_per_year(self, sort, dropna, group_keys, as_index): group_keys=group_keys, dropna=dropna, ).apply( - lambda grp: grp[grp.country == "brazil"].volume.sum() / grp.volume.sum() + lambda grp: grp[grp.country == "brazil"].volume.sum() + / grp.volume.sum(), + include_groups=include_groups, ), ) @@ -717,7 +759,7 @@ def test_volume_from_brazil_per_year(self, sort, dropna, group_keys, as_index): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_root_mean_squared_error(self): + def test_root_mean_squared_error(self, include_groups): """Test an example that a user provided here: https://groups.google.com/a/snowflake.com/g/snowpark-pandas-api-customer-adoption-DL/c/0PDdj9-p5Hs/m/pRJ-I08dBAAJ""" eval_snowpark_pandas_result( *create_test_dfs( @@ -728,7 +770,8 @@ def test_root_mean_squared_error(self): } ), lambda df: df.groupby("customer_id").apply( - lambda grp: np.sqrt((grp.actual - grp.expected) ** 2).mean() + lambda grp: np.sqrt((grp.actual - grp.expected) ** 2).mean(), + include_groups=include_groups, ), ) @@ -742,14 +785,15 @@ def test_root_mean_squared_error(self): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_multiindex_df(self, grouping_dfs_with_multiindexes, by, sort, as_index): + def test_multiindex_df( + self, grouping_dfs_with_multiindexes, by, sort, as_index, include_groups + ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, - lambda df: df.groupby( - by, - sort=sort, - as_index=as_index, - ).apply(lambda df: df.astype(str).astype(object).sum().sum()), + lambda df: df.groupby(by, sort=sort, as_index=as_index,).apply( + lambda df: df.astype(str).astype(object).sum().sum(), + include_groups=include_groups, + ), ) @pytest.mark.parametrize( @@ -779,13 +823,15 @@ def test_multiindex_df(self, grouping_dfs_with_multiindexes, by, sort, as_index) join_count=JOIN_COUNT, ) def test_non_series_or_dataframe_return_types( - self, return_value, grouping_dfs_with_multiindexes + self, return_value, grouping_dfs_with_multiindexes, include_groups ): """These return types are scalars in the sense 
that they are not pandas Series or DataFrames.""" snow_df, pandas_df = grouping_dfs_with_multiindexes def operation(df): - return df.groupby(level=0).apply(lambda df: return_value) + return df.groupby(level=0).apply( + lambda df: return_value, include_groups=include_groups + ) if return_value is None: # this is a pandas bug: https://github.com/pandas-dev/pandas/issues/57775 @@ -808,7 +854,7 @@ def operation(df): udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_group_apply_return_df_from_lambda(self): + def test_group_apply_return_df_from_lambda(self, include_groups): diamonds_path = ( pathlib.Path(__file__).parent.parent.parent.parent / "resources" @@ -824,25 +870,17 @@ def test_group_apply_return_df_from_lambda(self): lambda x: x.sort_values( "price", ascending=False, kind="mergesort" ).head(5), - include_groups=True, + include_groups=include_groups, ), ) - with pytest.raises( - NotImplementedError, - match="No support for groupby.apply with include_groups = False", - ): - pd.DataFrame(diamonds_pd).groupby("cut").apply( - lambda x: x.sort_values("price", ascending=False).head(5), - include_groups=False, - ) - @pytest.mark.xfail(strict=True, raises=AssertionError, reason="SNOW-1619940") - def test_return_timedelta(self): + def test_return_timedelta(self, include_groups): eval_snowpark_pandas_result( *create_test_dfs([[1, 2]]), lambda df: df.groupby(0).apply( - lambda df: native_pd.Timedelta(df.sum().sum()) + lambda df: native_pd.Timedelta(df.sum().sum()), + include_groups=include_groups, ), ) @@ -863,12 +901,16 @@ def test_return_timedelta(self): ), ], ) - def test_timedelta_input(self, pandas_df): + def test_timedelta_input(self, pandas_df, include_groups): eval_snowpark_pandas_result( - *create_test_dfs(pandas_df), lambda df: df.groupby(0).apply(lambda df: 1) + *create_test_dfs(pandas_df), + lambda df: df.groupby(0).apply(lambda df: 1, include_groups=include_groups), ) +@pytest.mark.parametrize( + "include_groups", [True, False], ids=lambda v: f"include_groups_{v}" +) class TestFuncReturnsSeries: @pytest.mark.parametrize( "by,level", @@ -892,7 +934,14 @@ class TestFuncReturnsSeries: join_count=JOIN_COUNT, ) def test_return_series_with_two_columns( - self, grouping_dfs_with_multiindexes, by, level, as_index, sort, group_keys + self, + grouping_dfs_with_multiindexes, + by, + level, + as_index, + sort, + group_keys, + include_groups, ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, @@ -903,11 +952,12 @@ def test_return_series_with_two_columns( { "custom_sum": group[("b", "int_col")].sum() + group[("b", "int_col")].max(), - "custom_string": group[("a", "string_col_1")].str.cat(sep="-") + "custom_string": group[("b", "string_col_2")].str.cat(sep="-") + group[("b", "string_col_2")].str.cat(sep="_"), }, name="custom_metrics", - ) + ), + include_groups=include_groups, ), ) @@ -916,7 +966,7 @@ def test_return_series_with_two_columns( udtf_count=UDTF_COUNT, join_count=JOIN_COUNT, ) - def test_args_and_kwargs(self, grouping_dfs_with_multiindexes): + def test_args_and_kwargs(self, grouping_dfs_with_multiindexes, include_groups): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby(level=0).apply( @@ -932,6 +982,7 @@ def test_args_and_kwargs(self, grouping_dfs_with_multiindexes): ), 7, kwarg1="x", + include_groups=include_groups, ), ) @@ -943,7 +994,7 @@ def test_args_and_kwargs(self, grouping_dfs_with_multiindexes): join_count=JOIN_COUNT + 1, ) @pytest.mark.parametrize("index", [[2.0, np.nan, 2.0, 1.0], [np.nan] * 4]) - def test_dropna(self, 
dropna, index): + def test_dropna(self, dropna, index, include_groups): pandas_index = native_pd.Index(index, name="index") if dropna and pandas_index.isna().all(): pytest.xfail( @@ -974,7 +1025,8 @@ def test_dropna(self, dropna, index): + group["string_col_2"].str.cat(sep="_"), }, name="custom_metrics", - ) + ), + include_groups=include_groups, ), ) @@ -985,7 +1037,7 @@ def test_dropna(self, dropna, index): join_count=JOIN_COUNT, ) def test_returning_series_with_different_names( - self, grouping_dfs_with_multiindexes + self, grouping_dfs_with_multiindexes, include_groups ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, @@ -995,8 +1047,9 @@ def test_returning_series_with_different_names( "int_sum": group[("b", "int_col")].sum(), "string_sum": group[("b", "string_col_2")].astype(object).sum(), }, - name="name_" + group[("a", "string_col_1")].iloc[0], - ) + name="name_" + str(group.iloc[0, 0]), + ), + include_groups=include_groups, ), ) @@ -1007,22 +1060,21 @@ def test_returning_series_with_different_names( join_count=JOIN_COUNT, ) def test_returning_series_with_conflicting_indexes( - self, grouping_dfs_with_multiindexes + self, grouping_dfs_with_multiindexes, include_groups ): eval_snowpark_pandas_result( *grouping_dfs_with_multiindexes, lambda df: df.groupby(("a", "string_col_1")).apply( lambda group: native_pd.Series( { - # Since we are grouping by ("a", "string_col_1"), the - # series we return for each group will have a different index. - group[("a", "string_col_1")] - .iloc[0]: group[("b", "int_col")] - .sum(), - group[("a", "string_col_1")].iloc[0] - + "_2": group[("b", "string_col_2")].astype(object).sum(), + str(group[("b", "int_col")].iloc[0]): group[ + ("b", "int_col") + ].sum(), + str(group[("b", "int_col")].iloc[0]) + + "_2": group[("b", "int_col")].astype(object).sum(), }, - ) + ), + include_groups=include_groups, ), ) @@ -1040,9 +1092,14 @@ def test_returning_series_with_conflicting_indexes( @pytest.mark.parametrize("group_keys", [True, False], ids=lambda v: f"group_keys_{v}") @pytest.mark.parametrize("sort", [True, False], ids=lambda v: f"sort_{v}") @pytest.mark.parametrize("dropna", [True, False], ids=lambda v: f"dropna_{v}") +@pytest.mark.parametrize( + "include_groups", [True, False], ids=lambda v: f"include_groups_{v}" +) class TestSeriesGroupBy: @pytest.mark.parametrize("by", ["string_col_1", ["index", "string_col_1"], "index"]) - def test_dataframe_groupby_getitem(self, by, func, dropna, group_keys, sort): + def test_dataframe_groupby_getitem( + self, by, func, dropna, group_keys, sort, include_groups + ): """Test apply() on a SeriesGroupBy that we get by DataFrameGroupBy.__getitem__""" qc = ( QUERY_COUNT_WITH_TRANSFORM_CHECK @@ -1077,22 +1134,32 @@ def test_dataframe_groupby_getitem(self, by, func, dropna, group_keys, sort): ), ), lambda df: df.groupby( - by, dropna=dropna, group_keys=group_keys, sort=sort - )["int_col"].apply(func), + by, + dropna=dropna, + group_keys=group_keys, + sort=sort, + )["int_col"].apply(func, include_groups=include_groups), ) @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1238546") - def test_grouping_series_by_self(self, func, dropna, group_keys, sort): + def test_grouping_series_by_self( + self, func, dropna, group_keys, sort, include_groups + ): """Test apply() on a SeriesGroupBy that we get by grouping a series by itself.""" eval_snowpark_pandas_result( *create_test_series([0, 1, 2]), lambda s: s.groupby( - s, dropna=dropna, group_keys=group_keys, sort=sort - ).apply(func), + s, + dropna=dropna, 
+ group_keys=group_keys, + sort=sort, + ).apply(func, include_groups=include_groups), ) @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1238546") - def test_grouping_series_by_external_by(self, func, dropna, group_keys, sort): + def test_grouping_series_by_external_by( + self, func, dropna, group_keys, sort, include_groups + ): """Test apply() on a SeriesGroupBy that we get by grouping a series by its index.""" # This example is from pandas SeriesGroupBy apply docstring. eval_snowpark_pandas_result( @@ -1102,7 +1169,7 @@ def test_grouping_series_by_external_by(self, func, dropna, group_keys, sort): dropna=dropna, group_keys=group_keys, sort=sort, - ).apply(func), + ).apply(func, include_groups=include_groups), ) @@ -1227,3 +1294,21 @@ def test_scalar_then_series_then_dataframe(self): else native_pd.DataFrame([[2, 4], [5, 6]]) ), ) + + +@sql_count_checker( + query_count=QUERY_COUNT_WITHOUT_TRANSFORM_CHECK, + join_count=JOIN_COUNT, + udtf_count=UDTF_COUNT, +) +def test_include_groups_default_value(grouping_dfs_with_multiindexes): + """ + Test that the default value behavior include_groups matches pandas. + + We don't test the default include_groups value for all test cases + because that would substantially increase the size of the test suite. + """ + eval_snowpark_pandas_result( + *grouping_dfs_with_multiindexes, + lambda df: df.groupby(("a", "string_col_1")).apply(lambda df: df.count()), + ) diff --git a/tests/integ/modin/test_sql_counter.py b/tests/integ/modin/test_sql_counter.py index 07ccab4c30a..7ec2c0acedc 100644 --- a/tests/integ/modin/test_sql_counter.py +++ b/tests/integ/modin/test_sql_counter.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import threading import modin.pandas as pd import numpy as np import pandas as native_pd @@ -150,29 +151,48 @@ def test_high_sql_count_pass(): def test_sql_count_with_joins(): + thread_id = threading.get_ident() with SqlCounter(query_count=1, join_count=1) as sql_counter: sql_counter._notify( - QueryRecord(query_id="1", sql_text="SELECT A FROM X JOIN Y") + QueryRecord( + query_id="1", sql_text="SELECT A FROM X JOIN Y", thread_id=thread_id + ) ) with SqlCounter(query_count=1, join_count=2) as sql_counter: sql_counter._notify( - QueryRecord(query_id="1", sql_text="SELECT A FROM X JOIN Y JOIN Z") + QueryRecord( + query_id="1", + sql_text="SELECT A FROM X JOIN Y JOIN Z", + thread_id=thread_id, + ) ) with SqlCounter(query_count=2, join_count=5) as sql_counter: sql_counter._notify( - QueryRecord(query_id="1", sql_text="SELECT A FROM X JOIN Y JOIN Z") + QueryRecord( + query_id="1", + sql_text="SELECT A FROM X JOIN Y JOIN Z", + thread_id=thread_id, + ) ) sql_counter._notify( - QueryRecord(query_id="2", sql_text="SELECT A FROM X JOIN Y JOIN Z JOIN W") + QueryRecord( + query_id="2", + sql_text="SELECT A FROM X JOIN Y JOIN Z JOIN W", + thread_id=thread_id, + ) ) def test_sql_count_by_query_substr(): with SqlCounter(query_count=1) as sql_counter: sql_counter._notify( - QueryRecord(query_id="1", sql_text="SELECT A FROM X JOIN Y JOIN W") + QueryRecord( + query_id="1", + sql_text="SELECT A FROM X JOIN Y JOIN W", + thread_id=threading.get_ident(), + ) ) assert sql_counter._count_by_query_substr(contains=[" JOIN "]) == 1 @@ -193,7 +213,11 @@ def test_sql_count_by_query_substr(): def test_sql_count_instances_by_query_substr(): with SqlCounter(query_count=1) as sql_counter: sql_counter._notify( - QueryRecord(query_id="1", sql_text="SELECT A FROM X JOIN Y JOIN W") + QueryRecord( + 
query_id="1", + sql_text="SELECT A FROM X JOIN Y JOIN W", + thread_id=threading.get_ident(), + ) ) assert sql_counter._count_instances_by_query_substr(contains=[" JOIN "]) == 2 diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index a42f7cf9947..8922faeda5b 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -174,10 +174,7 @@ def examples(structured_type_support): def structured_type_session(session, structured_type_support): if structured_type_support: with structured_types_enabled_session(session) as sess: - with mock.patch( - "snowflake.snowpark.context._use_structured_type_semantics", True - ): - yield sess + yield sess else: yield session diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 7c91222181b..8e9ca2e8ced 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -4447,3 +4447,15 @@ def test_map_negative(session): output_types=[IntegerType(), StringType()], output_column_names=["a", "b", "c"], ) + + +def test_with_column_keep_column_order(session): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["A", "B"]) + df1 = df.with_column("A", lit(0), keep_column_order=True) + assert df1.columns == ["A", "B"] + df2 = df.with_columns(["A"], [lit(0)], keep_column_order=True) + assert df2.columns == ["A", "B"] + df3 = df.with_columns(["A", "C"], [lit(0), lit(0)], keep_column_order=True) + assert df3.columns == ["A", "B", "C"] + df3 = df.with_columns(["C", "A"], [lit(0), lit(0)], keep_column_order=True) + assert df3.columns == ["A", "B", "C"] diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py index 11c5d196df0..88d2bf58532 100644 --- a/tests/integ/test_function.py +++ b/tests/integ/test_function.py @@ -15,6 +15,7 @@ from snowflake.snowpark import Row from snowflake.snowpark.exceptions import SnowparkSQLException from snowflake.snowpark.functions import ( + _concat_ws_ignore_nulls, abs, array_agg, array_append, @@ -175,7 +176,12 @@ TimestampType, VariantType, ) -from tests.utils import TestData, Utils, running_on_jenkins +from tests.utils import ( + TestData, + Utils, + running_on_jenkins, + structured_types_enabled_session, +) def test_order(session): @@ -308,6 +314,62 @@ def test_concat_ws(session, col_a, col_b, col_c): assert res[0][0] == "1,2,3" +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="lambda function not supported", +) +@pytest.mark.parametrize("structured_type_semantics", [True, False]) +def test__concat_ws_ignore_nulls(session, structured_type_semantics): + data = [ + (["a", "b"], ["c"], "d", "e", 1, 2), # no nulls column + ( + ["Hello", None, "world"], + [None, "!", None], + "bye", + "world", + 3, + None, + ), # some nulls column + ([None, None], ["R", "H"], None, "TD", 4, 5), # some nulls column + (None, [None], None, None, None, None), # all nulls column + (None, None, None, None, None, None), # all nulls column + ] + cols = ["arr1", "arr2", "str1", "str2", "int1", "int2"] + + def check_concat_ws_ignore_nulls_output(session): + df = session.create_dataframe(data, schema=cols) + + # single character delimiter + Utils.check_answer( + df.select(_concat_ws_ignore_nulls(",", *cols)), + [ + Row("a,b,c,d,e,1,2"), + Row("Hello,world,!,bye,world,3"), + Row("R,H,TD,4,5"), + Row(""), + Row(""), + ], + ) + + # multi-character delimiter + Utils.check_answer( + df.select(_concat_ws_ignore_nulls(" : ", *cols)), + [ + Row("a : b : c : d : e : 1 : 2"), + Row("Hello : 
world : ! : bye : world : 3"), + Row("R : H : TD : 4 : 5"), + Row(""), + Row(""), + ], + ) + + if structured_type_semantics: + with structured_types_enabled_session(session) as session: + check_concat_ws_ignore_nulls_output(session) + else: + check_concat_ws_ignore_nulls_output(session) + + def test_concat_edge_cases(session): df = session.create_dataframe( [[None, 1, 2, 3], [4, None, 6, 7], [8, 9, None, 11], [12, 13, 14, None]] diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 4516c4aa047..349103a2239 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -20,10 +20,7 @@ PlanState, ) from snowflake.snowpark._internal.compiler.cte_utils import find_duplicate_subtrees -from snowflake.snowpark.session import ( - _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION, - Session, -) +from snowflake.snowpark.session import Session from snowflake.snowpark.types import ( DoubleType, IntegerType, @@ -68,9 +65,6 @@ def threadsafe_session( else: new_db_parameters = db_parameters.copy() new_db_parameters["local_testing"] = local_testing_mode - new_db_parameters["session_parameters"] = { - _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: True - } with Session.builder.configs(new_db_parameters).create() as session: session._sql_simplifier_enabled = sql_simplifier_enabled yield session @@ -888,36 +882,28 @@ def process_data(df_, thread_id): assert len(unique_drop_file_format_queries) == 10 -@pytest.mark.skipif( - IS_IN_STORED_PROC, reason="Cannot create new session inside stored proc" +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="cursor are not created in local testing mode", + run=False, ) -@pytest.mark.parametrize("is_enabled", [True, False]) -def test_num_cursors_created(db_parameters, is_enabled, local_testing_mode): - if is_enabled and local_testing_mode: - pytest.skip("Multithreading is enabled by default in local testing mode") - - num_workers = 5 if is_enabled else 1 - new_db_parameters = db_parameters.copy() - new_db_parameters["session_parameters"] = { - _PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION: is_enabled - } - - with Session.builder.configs(new_db_parameters).create() as new_session: +def test_num_cursors_created(threadsafe_session): + num_workers = 5 - def run_query(session_, thread_id): - assert session_.sql(f"SELECT {thread_id} as A").collect()[0][0] == thread_id + def run_query(session_, thread_id): + assert session_.sql(f"SELECT {thread_id} as A").collect()[0][0] == thread_id - with patch.object( - new_session._conn._telemetry_client, "send_cursor_created_telemetry" - ) as mock_telemetry: - with ThreadPoolExecutor(max_workers=num_workers) as executor: - for i in range(10): - executor.submit(run_query, new_session, i) + with patch.object( + threadsafe_session._conn._telemetry_client, "send_cursor_created_telemetry" + ) as mock_telemetry: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + for i in range(10): + executor.submit(run_query, threadsafe_session, i) # when multithreading is enabled, each worker will create a cursor # otherwise, we will use the same cursor created by the main thread # thus creating 0 new cursors. 
- assert mock_telemetry.call_count == (num_workers if is_enabled else 0) + assert mock_telemetry.call_count == num_workers @pytest.mark.xfail( diff --git a/tests/integ/utils/sql_counter.py b/tests/integ/utils/sql_counter.py index 2f1c5297b59..520d4bab7e7 100644 --- a/tests/integ/utils/sql_counter.py +++ b/tests/integ/utils/sql_counter.py @@ -95,6 +95,7 @@ sql_count_records = {} sql_counter_state = threading.local() +sql_counter_lock = threading.RLock() class SqlCounter(QueryListener): @@ -134,6 +135,9 @@ def __init__( self._queries: list[QueryRecord] = [] + # Track the thread id to ensure we only count queries from the current thread. + self._current_thread_id = threading.get_ident() + # Bypassing sql counter since # 1. it is an unnecessary metric for tests running in stored procedures # 2. pytest-assume package is not available in conda @@ -174,6 +178,10 @@ def __init__( def include_describe(self) -> bool: return True + @property + def include_thread_id(self) -> bool: + return True + @staticmethod def set_record_mode(record_mode): """Record mode means the SqlCounter does not assert any results, but rather collects them so they can @@ -202,7 +210,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _notify(self, query_record: QueryRecord, **kwargs: dict): if not is_suppress_sql_counter_listener(): - self._queries.append(query_record) + if query_record.thread_id == self._current_thread_id: + self._queries.append(query_record) def expects(self, **kwargs): """ @@ -653,16 +662,18 @@ def generate_sql_count_report(request, counter): def mark_sql_counter_called(): - threading.current_thread().__dict__[SQL_COUNTER_CALLED] = True + with sql_counter_lock: + threading.main_thread().__dict__[SQL_COUNTER_CALLED] = True def clear_sql_counter_called(): - threading.current_thread().__dict__[SQL_COUNTER_CALLED] = False + with sql_counter_lock: + threading.main_thread().__dict__[SQL_COUNTER_CALLED] = False def is_sql_counter_called(): - if SQL_COUNTER_CALLED in threading.current_thread().__dict__: - return threading.current_thread().__dict__.get(SQL_COUNTER_CALLED) + with sql_counter_lock: + return threading.main_thread().__dict__.get(SQL_COUNTER_CALLED, False) return False diff --git a/tests/mock/conftest.py b/tests/mock/conftest.py index b1addd0fd3e..61bf14462d6 100644 --- a/tests/mock/conftest.py +++ b/tests/mock/conftest.py @@ -10,13 +10,8 @@ @pytest.fixture(scope="function") -def mock_server_connection(multithreading_mode_enabled): - options = { - "session_parameters": { - "PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": multithreading_mode_enabled - } - } - s = MockServerConnection(options) +def mock_server_connection(): + s = MockServerConnection() yield s s.close() diff --git a/tests/mock/test_multithreading.py b/tests/mock/test_multithreading.py index e1720a2ea0e..8cd77e9c514 100644 --- a/tests/mock/test_multithreading.py +++ b/tests/mock/test_multithreading.py @@ -17,7 +17,6 @@ ) from snowflake.snowpark._internal.utils import normalize_local_file from snowflake.snowpark.functions import lit, when_matched -from snowflake.snowpark.mock._connection import MockServerConnection from snowflake.snowpark.mock._functions import MockedFunctionRegistry from snowflake.snowpark.mock._plan import MockExecutionPlan from snowflake.snowpark.mock._snowflake_data_type import TableEmulator @@ -28,19 +27,9 @@ from tests.utils import Utils -@pytest.fixture(scope="function", autouse=True) -def threadsafe_server_connection(): - options = { - "session_parameters": {"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION": True} - } - 
s = MockServerConnection(options) - yield s - s.close() - - @pytest.fixture(scope="function") -def threadsafe_session(threadsafe_server_connection): - with Session(threadsafe_server_connection) as s: +def threadsafe_session(mock_server_connection): + with Session(mock_server_connection) as s: yield s @@ -172,8 +161,8 @@ def test_mocked_function_registry_created_once(): @pytest.mark.parametrize("test_table", [True, False]) -def test_tabular_entity_registry(test_table, threadsafe_server_connection): - entity_registry = threadsafe_server_connection.entity_registry +def test_tabular_entity_registry(test_table, mock_server_connection): + entity_registry = mock_server_connection.entity_registry num_threads = 10 def write_read_and_drop_table(): @@ -210,8 +199,8 @@ def write_read_and_drop_view(): future.result() -def test_stage_entity_registry_put_and_get(threadsafe_server_connection): - stage_registry = StageEntityRegistry(threadsafe_server_connection) +def test_stage_entity_registry_put_and_get(mock_server_connection): + stage_registry = StageEntityRegistry(mock_server_connection) num_threads = 10 def put_and_get_file(): @@ -239,9 +228,9 @@ def put_and_get_file(): def test_stage_entity_registry_upload_and_read( - threadsafe_session, threadsafe_server_connection + threadsafe_session, mock_server_connection ): - stage_registry = StageEntityRegistry(threadsafe_server_connection) + stage_registry = StageEntityRegistry(mock_server_connection) num_threads = 10 def upload_and_read_json(thread_id: int): @@ -270,8 +259,8 @@ def upload_and_read_json(thread_id: int): future.result() -def test_stage_entity_registry_create_or_replace(threadsafe_server_connection): - stage_registry = StageEntityRegistry(threadsafe_server_connection) +def test_stage_entity_registry_create_or_replace(mock_server_connection): + stage_registry = StageEntityRegistry(mock_server_connection) num_threads = 10 with ThreadPoolExecutor(max_workers=num_threads) as executor: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 988590c52ce..f7aaa93e0f8 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -20,7 +20,6 @@ def mock_server_connection() -> ServerConnection: fake_snowflake_connection._conn = mock.MagicMock() fake_snowflake_connection._telemetry = None fake_snowflake_connection._session_parameters = {} - fake_snowflake_connection._thread_safe_session_enabled = True fake_snowflake_connection.cursor.return_value = mock.create_autospec( SnowflakeCursor ) @@ -34,7 +33,6 @@ def closed_mock_server_connection() -> ServerConnection: fake_snowflake_connection._conn = mock.MagicMock() fake_snowflake_connection._telemetry = None fake_snowflake_connection._session_parameters = {} - fake_snowflake_connection._thread_safe_session_enabled = True fake_snowflake_connection.is_closed = mock.MagicMock(return_value=False) fake_snowflake_connection.cursor.return_value = mock.create_autospec( SnowflakeCursor diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index 9bdd89493dc..1693dca258c 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -120,7 +120,6 @@ def nop(name): fake_session._cte_optimization_enabled = False fake_session._query_compilation_stage_enabled = False fake_session._conn = mock.create_autospec(ServerConnection) - fake_session._conn._thread_safe_session_enabled = False fake_session._plan_builder = SnowflakePlanBuilder(fake_session) fake_session._analyzer = Analyzer(fake_session) fake_session._use_scoped_temp_objects = True @@ -280,7 +279,6 @@ def 
test_same_joins_should_generate_same_queries(join_type, mock_server_connecti def test_statement_params(): mock_connection = mock.create_autospec(ServerConnection) mock_connection._conn = mock.MagicMock() - mock_connection._thread_safe_session_enabled = True session = snowflake.snowpark.session.Session(mock_connection) session._conn._telemetry_client = mock.MagicMock() df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) @@ -325,7 +323,6 @@ def test_session(): def test_table_source_plan(sql_simplifier_enabled): mock_connection = mock.create_autospec(ServerConnection) mock_connection._conn = mock.MagicMock() - mock_connection._thread_safe_session_enabled = True session = snowflake.snowpark.session.Session(mock_connection) session._sql_simplifier_enabled = sql_simplifier_enabled t = session.table("table") diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9414d610c5d..be008045d00 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -71,7 +71,6 @@ def test_get_active_session_when_no_active_sessions(): def test_used_scoped_temp_object(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() - fake_connection._thread_safe_session_enabled = True fake_connection._get_client_side_session_parameter = ( lambda x, y: ServerConnection._get_client_side_session_parameter( @@ -116,7 +115,6 @@ def test_used_scoped_temp_object(): def test_close_exception(): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() - fake_connection._thread_safe_session_enabled = True fake_connection._telemetry_client = mock.Mock() fake_connection.is_closed = MagicMock(return_value=False) exception_msg = "Mock exception for session.cancel_all" @@ -204,7 +202,6 @@ def mock_get_information_schema_packages(table_name: str): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock() - fake_connection._thread_safe_session_enabled = True fake_connection._get_current_parameter = mock_get_current_parameter session = Session(fake_connection) session.table = MagicMock(name="session.table") @@ -437,7 +434,6 @@ def test_parse_table_name(): def test_session_id(): fake_server_connection = mock.create_autospec(ServerConnection) - fake_server_connection._thread_safe_session_enabled = True fake_server_connection.get_session_id = mock.Mock(return_value=123456) session = Session(fake_server_connection) diff --git a/tests/utils.py b/tests/utils.py index 5ccbb7930c3..6086169d5da 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,11 +10,12 @@ import random import string import uuid -from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from typing import Dict, List, NamedTuple, Optional, Union +from threading import Thread +from unittest import mock import pytest import pytz @@ -121,7 +122,8 @@ def iceberg_supported(session, local_testing_mode): def structured_types_enabled_session(session): for param in STRUCTURED_TYPE_PARAMETERS: session.sql(f"alter session set {param}=true").collect() - yield session + with mock.patch("snowflake.snowpark.context._use_structured_type_semantics", True): + yield session for param in STRUCTURED_TYPE_PARAMETERS: session.sql(f"alter session unset {param}").collect() @@ -138,17 +140,17 @@ def running_on_jenkins() -> bool: def multithreaded_run(num_threads: int = 5) -> None: """When multithreading_mode is enabled, run the 
decorated test function in multiple threads.""" - from tests.conftest import MULTITHREADING_TEST_MODE_ENABLED def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - if MULTITHREADING_TEST_MODE_ENABLED: - with ThreadPoolExecutor(max_workers=num_threads) as executor: - for _ in range(num_threads): - executor.submit(func, *args, **kwargs) - else: - func(*args, **kwargs) + all_threads = [] + for _ in range(num_threads): + job = Thread(target=func, args=args, kwargs=kwargs) + all_threads.append(job) + job.start() + for thread in all_threads: + thread.join() return wrapper diff --git a/tox.ini b/tox.ini index 86c3736d246..60a6635d3a6 100644 --- a/tox.ini +++ b/tox.ini @@ -98,10 +98,8 @@ commands = notudf: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and not udf" {posargs:} src/snowflake/snowpark tests udf: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} or udf" {posargs:} src/snowflake/snowpark tests notdoctest: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} or udf" {posargs:} tests - notmultithreaded: {env:SNOWFLAKE_PYTEST_CMD} --disable_multithreading_mode -m "{env:SNOWFLAKE_TEST_TYPE} or udf" {posargs:} tests notudfdoctest: {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} and not udf" {posargs:} tests local: {env:SNOWFLAKE_PYTEST_CMD} --local_testing_mode -m "integ or unit or mock" {posargs:} tests - localnotmultithreaded: {env:SNOWFLAKE_PYTEST_CMD} --disable_multithreading_mode --local_testing_mode -m "integ or unit or mock" {posargs:} tests dailynotdoctest: {env:SNOWFLAKE_PYTEST_DAILY_CMD} -m "{env:SNOWFLAKE_TEST_TYPE} or udf" {posargs:} tests # Snowpark pandas commands: snowparkpandasnotdoctest: {env:MODIN_PYTEST_CMD} --durations=20 -m "{env:SNOWFLAKE_TEST_TYPE}" {posargs:} {env:SNOW_1314507_WORKAROUND_RERUN_FLAGS} tests/unit/modin tests/integ/modin tests/integ/test_df_to_snowpark_pandas.py
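Reference sketch: the recurring change in this patch is threading an `include_groups` parametrization through the groupby.apply tests. The snippet below is a minimal illustration of the semantics those tests exercise, assuming plain pandas (>= 2.2) behavior that the Snowpark pandas tests mirror; it is not part of the patch.

    import pandas as pd

    df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})

    # include_groups=True (the legacy default) passes the grouping column(s)
    # through to func, so each group frame handed to the lambda still
    # contains the "key" column.
    with_groups = df.groupby("key").apply(lambda g: g.sum(), include_groups=True)

    # include_groups=False drops the grouping column(s) before calling func,
    # which is the case the new include_groups_False test ids cover.
    without_groups = df.groupby("key").apply(lambda g: g.sum(), include_groups=False)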