Merge branch 'main' into lmukhopadhyay-SNOW-1874368-skip-cortex-doctests
sfc-gh-jrose authored Jan 14, 2025
2 parents ac9afab + 8581f70 commit 2ea3bbd
Showing 35 changed files with 606 additions and 476 deletions.
69 changes: 0 additions & 69 deletions .github/workflows/daily_precommit.yml
@@ -546,75 +546,6 @@ jobs:
.tox/.coverage
.tox/coverage.xml
test-snowpark-disable-multithreading-mode:
name: Test Snowpark Multithreading Disabled py-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
needs: build
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest-64-cores]
python-version: ["3.9"]
cloud-provider: [aws]
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Decrypt parameters.py
shell: bash
run: .github/scripts/decrypt_parameters.sh
env:
PARAMETER_PASSWORD: ${{ secrets.PARAMETER_PASSWORD }}
CLOUD_PROVIDER: ${{ matrix.cloud-provider }}
- name: Install protoc
shell: bash
run: .github/scripts/install_protoc.sh
- name: Download wheel(s)
uses: actions/download-artifact@v4
with:
name: wheel
path: dist
- name: Show wheels downloaded
run: ls -lh dist
shell: bash
- name: Upgrade setuptools, pip and wheel
run: python -m pip install -U setuptools pip wheel
- name: Install tox
run: python -m pip install tox
- name: Run tests (excluding doctests)
run: python -m tox -e "py${PYTHON_VERSION/\./}-notmultithreaded-ci"
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
PYTEST_ADDOPTS: --color=yes --tb=short
TOX_PARALLEL_NO_SPINNER: 1
shell: bash
- name: Run local tests
run: python -m tox -e "py${PYTHON_VERSION/\./}-localnotmultithreaded-ci"
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
PYTEST_ADDOPTS: --color=yes --tb=short
TOX_PARALLEL_NO_SPINNER: 1
shell: bash
- name: Combine coverages
run: python -m tox -e coverage --skip-missing-interpreters false
shell: bash
env:
SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1
- uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: coverage_${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}-snowpark-multithreading
path: |
.tox/.coverage
.tox/coverage.xml
combine-coverage:
if: ${{ success() || failure() }}
name: Combine coverage
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -19,11 +19,13 @@
#### Improvements

- Updated README.md to include instructions on how to verify package signatures using `cosign`.
- Added an option `keep_column_order` for keeping original column order in `DataFrame.with_column` and `DataFrame.with_columns`.

#### Bug Fixes

- Fixed a bug in local testing mode that caused a column to contain None when it should contain 0.
- Fixed a bug in StructField.from_json that prevented TimestampTypes with tzinfo from being parsed correctly.
- Fixed a bug in `StructField.from_json` that prevented TimestampTypes with tzinfo from being parsed correctly.
- Fixed a bug in function `date_format` that caused an error when the input column was date type or timestamp type.

### Snowpark pandas API Updates

@@ -49,6 +51,7 @@
- %X: Locale’s appropriate time representation.
- %%: A literal '%' character.
- Added support for `Series.between`.
- Added support for `include_groups=False` in `DataFrameGroupBy.apply`.

#### Bug Fixes

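For the `keep_column_order` entry above, a minimal usage sketch, assuming the option defaults to False and that a replaced column is otherwise appended at the end of the schema (only the keyword name comes from the changelog; the rest is inferred):

    from snowflake.snowpark import Session
    from snowflake.snowpark.functions import col

    session = Session.builder.getOrCreate()  # assumes a configured connection
    df = session.create_dataframe([[1, 2, 3]], schema=["a", "b", "c"])

    # Replacing "b" has historically moved it to the end of the column list
    # (a, c, b); with keep_column_order=True it should stay put (a, b, c).
    df2 = df.with_column("b", col("b") + 1, keep_column_order=True)
    print(df2.columns)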
4 changes: 2 additions & 2 deletions docs/source/modin/supported/groupby_supported.rst
@@ -39,8 +39,8 @@ Function application
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``apply`` | P | ``axis`` other than 0 is not | ``Y`` if the following are true, otherwise ``N``: |
| | | implemented. | - ``func`` is a callable that always returns |
| | | ``include_groups = False`` is | either a pandas DataFrame, a pandas Series, or |
| | | not implemented. | objects that are neither DataFrame nor Series. |
| | | | either a pandas DataFrame, a pandas Series, or |
| | | | objects that are neither DataFrame nor Series. |
| | | | - grouping on axis=0 |
| | | | - Not applying transform to a dataframe with a |
| | | | non-unique index |
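The table change above tracks the `include_groups=False` support added in this commit. A short sketch, assuming an active Snowpark session and pandas 2.x semantics (the grouping columns are excluded from the frame passed to `func`):

    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401, registers the Snowpark pandas backend

    df = pd.DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})
    # With include_groups=False, each group frame passed to the callable
    # contains only "val", not the grouping column "key".
    result = df.groupby("key").apply(lambda g: g.sum(), include_groups=False)
    print(result)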
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -0,0 +1,7 @@
[build-system]
requires = [
"setuptools",
"protoc-wheel-0==21.1", # Protocol buffer compiler for Snowpark IR
"mypy-protobuf", # used in generating typed Python code from protobuf for Snowpark IR
]
build-backend = "setuptools.build_meta"
3 changes: 1 addition & 2 deletions setup.py
@@ -58,9 +58,8 @@
"graphviz", # used in plot tests
"pytest-assume", # sql counter check
"decorator", # sql counter check
"protoc-wheel-0==21.1", # Protocol buffer compiler, for Snowpark IR
"mypy-protobuf", # used in generating typed Python code from protobuf for Snowpark IR
"lxml", # used in read_xml tests
"tox", # used for setting up testing environments
]

# read the version
21 changes: 21 additions & 0 deletions src/conftest.py
@@ -41,6 +41,27 @@ def pytest_runtest_makereport(item, call):
return tr


# These tests require python packages that are no longer built for python 3.8
PYTHON_38_SKIPS = {
"snowpark.session.Session.replicate_local_environment",
"snowpark.session.Session.table_function",
}

DocTestFinder = doctest.DocTestFinder


class CustomDocTestFinder(DocTestFinder):
def _find(self, tests, obj, name, module, source_lines, globs, seen):
if name in PYTHON_38_SKIPS and sys.version_info < (3, 9):
return
return DocTestFinder._find(
self, tests, obj, name, module, source_lines, globs, seen
)


doctest.DocTestFinder = CustomDocTestFinder


# scope is module so that we ensure we delete the session before
# moving onto running the tests in the tests dir. Having only one
# session is important to certain UDF tests to pass, since they
26 changes: 5 additions & 21 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -108,7 +108,6 @@
generate_random_alphanumeric,
get_copy_into_table_options,
is_sql_select_statement,
random_name_for_temp_object,
)
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import StructType
@@ -664,12 +663,8 @@ def large_local_relation_plan(
source_plan: Optional[LogicalPlan],
schema_query: Optional[str],
) -> SnowflakePlan:
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
temp_table_name = (
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
temp_table_name = f"temp_name_placeholder_{generate_random_alphanumeric()}"

attributes = [
Attribute(attr.name, attr.datatype, attr.nullable) for attr in output
]
@@ -696,9 +691,7 @@ def large_local_relation_plan(
Query(
create_table_stmt,
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE),
),
BatchInsertQuery(insert_stmt, data),
Query(select_stmt),
@@ -1215,7 +1208,6 @@ def read_file(
metadata_project: Optional[List[str]] = None,
metadata_schema: Optional[List[Attribute]] = None,
):
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
format_type_options, copy_options = get_copy_into_table_options(options)
format_type_options = self._merge_file_format_options(
format_type_options, options
@@ -1247,8 +1239,6 @@ def read_file(
post_queries: List[Query] = []
format_name = self.session.get_fully_qualified_name_if_possible(
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.FILE_FORMAT)
)
queries.append(
Query(
@@ -1262,9 +1252,7 @@
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT),
)
)
post_queries.append(
@@ -1323,8 +1311,6 @@ def read_file(

temp_table_name = self.session.get_fully_qualified_name_if_possible(
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
queries = [
Query(
@@ -1337,9 +1323,7 @@
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE),
),
Query(
copy_into_table(
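The edits above drop the `thread_safe_session_enabled` branch, so every temp object now starts life with a placeholder name. A sketch of the naming scheme; the body of `generate_random_alphanumeric` is an assumption, only its use in the f-string appears in the diff:

    import random
    import string

    def generate_random_alphanumeric(k: int = 10) -> str:
        # Assumed behavior: a short random alphanumeric suffix.
        return "".join(random.choices(string.ascii_lowercase + string.digits, k=k))

    # The real temp object name is substituted at execution time
    # (see plan_compiler.py below).
    temp_table_name = f"temp_name_placeholder_{generate_random_alphanumeric()}"
    print(temp_table_name)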
60 changes: 28 additions & 32 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
@@ -204,35 +204,31 @@ def replace_temp_obj_placeholders(
To prevent this, we generate queries with temp object name placeholders and replace them with actual temp object
here.
"""
session = self._plan.session
if session._conn._thread_safe_session_enabled:
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries

return queries
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries
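A standalone sketch of the replacement step shown above: generate a fresh temp object name per execution, then swap it into every query's SQL text. The `Query` stand-in and the SQL strings are illustrative; the loop mirrors the diff:

    import copy

    class Query:  # simplified stand-in for the internal Query class
        def __init__(self, sql, temp_obj_name_placeholder=None):
            self.sql = sql
            self.temp_obj_name_placeholder = temp_obj_name_placeholder

    placeholder = "temp_name_placeholder_ab12cd"
    queries = [
        Query(f"CREATE TEMP TABLE {placeholder} (a INT)", (placeholder, "TABLE")),
        Query(f"SELECT * FROM {placeholder}"),
    ]

    placeholders = {}
    executed = []
    for query in queries:
        if query.temp_obj_name_placeholder:
            name, _obj_type = query.temp_obj_name_placeholder
            placeholders[name] = "SNOWPARK_TEMP_TABLE_XYZ"  # illustrative generated name
        copied = copy.copy(query)
        for name, target in placeholders.items():
            copied.sql = copied.sql.replace(name, target)
        executed.append(copied)
    print([q.sql for q in executed])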
10 changes: 2 additions & 8 deletions src/snowflake/snowpark/_internal/server_connection.py
@@ -51,8 +51,6 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryClient
from snowflake.snowpark._internal.utils import (
create_rlock,
create_thread_local,
escape_quotes,
get_application_name,
get_version,
@@ -173,12 +171,8 @@ def __init__(
except TypeError:
pass

# thread safe param protection
self._thread_safe_session_enabled = self._get_client_side_session_parameter(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._thread_store = create_thread_local(self._thread_safe_session_enabled)
self._lock = threading.RLock()
self._thread_store = threading.local()

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
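For context on the deletions above: the removed `create_rlock` and `create_thread_local` helpers returned real synchronization primitives only when the now-retired thread-safety parameter was on. A sketch of the assumed old pattern, inferred from the call sites (the actual helper bodies are not shown in this diff):

    import threading
    from contextlib import nullcontext
    from types import SimpleNamespace

    def create_rlock(thread_safe_enabled: bool):
        # Assumed: a real lock only when thread safety is enabled;
        # nullcontext() supports "with" but provides no mutual exclusion.
        return threading.RLock() if thread_safe_enabled else nullcontext()

    def create_thread_local(thread_safe_enabled: bool):
        # Assumed: real thread-local storage when enabled, a plain
        # attribute bag otherwise.
        return threading.local() if thread_safe_enabled else SimpleNamespace()

With thread safety now unconditional, the branches collapse to `threading.RLock()` and `threading.local()`, which is exactly what the new code assigns.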
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
@@ -2,12 +2,13 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
from threading import RLock
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
from snowflake.snowpark._internal.utils import create_rlock, is_in_stored_procedure
from snowflake.snowpark._internal.utils import is_in_stored_procedure

_logger = logging.getLogger(__name__)

@@ -34,7 +35,7 @@ def __init__(self, session: "Session") -> None:
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = create_rlock(session._conn._thread_safe_session_enabled)
self.lock = RLock()

def add(self, table: SnowflakeTable) -> None:
with self.lock:
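The cleaner's `ref_count_map` guarded by the `RLock` is a standard reference-counting pattern; a minimal standalone sketch (simplified: the real cleaner also drops the temp table once its count reaches zero):

    import threading
    from collections import defaultdict

    class RefCounter:
        def __init__(self) -> None:
            self.lock = threading.RLock()
            self.ref_count_map = defaultdict(int)

        def add(self, name: str) -> None:
            with self.lock:
                self.ref_count_map[name] += 1

        def release(self, name: str) -> int:
            with self.lock:
                self.ref_count_map[name] -= 1
                return self.ref_count_map[name]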
(Diff truncated: the remaining 25 of the 35 changed files are not shown.)
