Skip to content

Commit

Permalink
SNOW-1720855: clean up multithreading changes after rollout (#2658)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam authored Jan 14, 2025
1 parent 66dc14d commit 5b40dda
Show file tree
Hide file tree
Showing 20 changed files with 136 additions and 318 deletions.
69 changes: 0 additions & 69 deletions .github/workflows/daily_precommit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -546,75 +546,6 @@ jobs:
.tox/.coverage
.tox/coverage.xml
test-snowpark-disable-multithreading-mode:
name: Test Snowpark Multithreading Disabled py-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
needs: build
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest-64-cores]
python-version: ["3.9"]
cloud-provider: [aws]
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Decrypt parameters.py
shell: bash
run: .github/scripts/decrypt_parameters.sh
env:
PARAMETER_PASSWORD: ${{ secrets.PARAMETER_PASSWORD }}
CLOUD_PROVIDER: ${{ matrix.cloud-provider }}
- name: Install protoc
shell: bash
run: .github/scripts/install_protoc.sh
- name: Download wheel(s)
uses: actions/download-artifact@v4
with:
name: wheel
path: dist
- name: Show wheels downloaded
run: ls -lh dist
shell: bash
- name: Upgrade setuptools, pip and wheel
run: python -m pip install -U setuptools pip wheel
- name: Install tox
run: python -m pip install tox
- name: Run tests (excluding doctests)
run: python -m tox -e "py${PYTHON_VERSION/\./}-notmultithreaded-ci"
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
PYTEST_ADDOPTS: --color=yes --tb=short
TOX_PARALLEL_NO_SPINNER: 1
shell: bash
- name: Run local tests
run: python -m tox -e "py${PYTHON_VERSION/\./}-localnotmultithreaded-ci"
env:
PYTHON_VERSION: ${{ matrix.python-version }}
cloud_provider: ${{ matrix.cloud-provider }}
PYTEST_ADDOPTS: --color=yes --tb=short
TOX_PARALLEL_NO_SPINNER: 1
shell: bash
- name: Combine coverages
run: python -m tox -e coverage --skip-missing-interpreters false
shell: bash
env:
SNOWFLAKE_IS_PYTHON_RUNTIME_TEST: 1
- uses: actions/upload-artifact@v4
with:
include-hidden-files: true
name: coverage_${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}-snowpark-multithreading
path: |
.tox/.coverage
.tox/coverage.xml
combine-coverage:
if: ${{ success() || failure() }}
name: Combine coverage
Expand Down
26 changes: 5 additions & 21 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@
generate_random_alphanumeric,
get_copy_into_table_options,
is_sql_select_statement,
random_name_for_temp_object,
)
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import StructType
Expand Down Expand Up @@ -664,12 +663,8 @@ def large_local_relation_plan(
source_plan: Optional[LogicalPlan],
schema_query: Optional[str],
) -> SnowflakePlan:
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
temp_table_name = (
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
temp_table_name = f"temp_name_placeholder_{generate_random_alphanumeric()}"

attributes = [
Attribute(attr.name, attr.datatype, attr.nullable) for attr in output
]
Expand All @@ -696,9 +691,7 @@ def large_local_relation_plan(
Query(
create_table_stmt,
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE),
),
BatchInsertQuery(insert_stmt, data),
Query(select_stmt),
Expand Down Expand Up @@ -1215,7 +1208,6 @@ def read_file(
metadata_project: Optional[List[str]] = None,
metadata_schema: Optional[List[Attribute]] = None,
):
thread_safe_session_enabled = self.session._conn._thread_safe_session_enabled
format_type_options, copy_options = get_copy_into_table_options(options)
format_type_options = self._merge_file_format_options(
format_type_options, options
Expand Down Expand Up @@ -1247,8 +1239,6 @@ def read_file(
post_queries: List[Query] = []
format_name = self.session.get_fully_qualified_name_if_possible(
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.FILE_FORMAT)
)
queries.append(
Query(
Expand All @@ -1262,9 +1252,7 @@ def read_file(
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(format_name, TempObjectType.FILE_FORMAT),
)
)
post_queries.append(
Expand Down Expand Up @@ -1323,8 +1311,6 @@ def read_file(

temp_table_name = self.session.get_fully_qualified_name_if_possible(
f"temp_name_placeholder_{generate_random_alphanumeric()}"
if thread_safe_session_enabled
else random_name_for_temp_object(TempObjectType.TABLE)
)
queries = [
Query(
Expand All @@ -1337,9 +1323,7 @@ def read_file(
is_generated=True,
),
is_ddl_on_temp_object=True,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE)
if thread_safe_session_enabled
else None,
temp_obj_name_placeholder=(temp_table_name, TempObjectType.TABLE),
),
Query(
copy_into_table(
Expand Down
60 changes: 28 additions & 32 deletions src/snowflake/snowpark/_internal/compiler/plan_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,35 +204,31 @@ def replace_temp_obj_placeholders(
To prevent this, we generate queries with temp object name placeholders and replace them with actual temp object
here.
"""
session = self._plan.session
if session._conn._thread_safe_session_enabled:
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries

return queries
# This dictionary will store the mapping between placeholder name and actual temp object name.
placeholders = {}
# Final execution queries
execution_queries = {}
for query_type, query_list in queries.items():
execution_queries[query_type] = []
for query in query_list:
# If the query contains a temp object name placeholder, we generate a random
# name for the temp object and add it to the placeholders dictionary.
if query.temp_obj_name_placeholder:
(
placeholder_name,
temp_obj_type,
) = query.temp_obj_name_placeholder
placeholders[placeholder_name] = random_name_for_temp_object(
temp_obj_type
)

copied_query = copy.copy(query)
for placeholder_name, target_temp_name in placeholders.items():
# Copy the original query and replace all the placeholder names with the
# actual temp object names.
copied_query.sql = copied_query.sql.replace(
placeholder_name, target_temp_name
)

execution_queries[query_type].append(copied_query)
return execution_queries
10 changes: 2 additions & 8 deletions src/snowflake/snowpark/_internal/server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import TelemetryClient
from snowflake.snowpark._internal.utils import (
create_rlock,
create_thread_local,
escape_quotes,
get_application_name,
get_version,
Expand Down Expand Up @@ -173,12 +171,8 @@ def __init__(
except TypeError:
pass

# thread safe param protection
self._thread_safe_session_enabled = self._get_client_side_session_parameter(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._thread_store = create_thread_local(self._thread_safe_session_enabled)
self._lock = threading.RLock()
self._thread_store = threading.local()

if "password" in self._lower_case_parameters:
self._lower_case_parameters["password"] = None
Expand Down
5 changes: 3 additions & 2 deletions src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
import logging
from threading import RLock
import weakref
from collections import defaultdict
from typing import TYPE_CHECKING, Dict

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import SnowflakeTable
from snowflake.snowpark._internal.utils import create_rlock, is_in_stored_procedure
from snowflake.snowpark._internal.utils import is_in_stored_procedure

_logger = logging.getLogger(__name__)

Expand All @@ -34,7 +35,7 @@ def __init__(self, session: "Session") -> None:
# this dict will still be maintained even if the cleaner is stopped (`stop()` is called)
self.ref_count_map: Dict[str, int] = defaultdict(int)
# Lock to protect the ref_count_map
self.lock = create_rlock(session._conn._thread_safe_session_enabled)
self.lock = RLock()

def add(self, table: SnowflakeTable) -> None:
with self.lock:
Expand Down
48 changes: 1 addition & 47 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,12 +378,7 @@ def normalize_path(path: str, is_local: bool) -> str:
return f"'{path}'"


def warn_session_config_update_in_multithreaded_mode(
config: str, thread_safe_mode_enabled: bool
) -> None:
if not thread_safe_mode_enabled:
return

def warn_session_config_update_in_multithreaded_mode(config: str) -> None:
if threading.active_count() > 1:
_logger.warning(
"You might have more than one threads sharing the Session object trying to update "
Expand Down Expand Up @@ -798,47 +793,6 @@ def warning(self, text: str) -> None:
self.count += 1


# TODO: SNOW-1720855: Remove DummyRLock and DummyThreadLocal after the rollout
class DummyRLock:
"""This is a dummy lock that is used in place of threading.Rlock when multithreading is
disabled."""

def __enter__(self):
pass

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def acquire(self, *args, **kwargs):
pass # pragma: no cover

def release(self, *args, **kwargs):
pass # pragma: no cover


class DummyThreadLocal:
"""This is a dummy thread local class that is used in place of threading.local when
multithreading is disabled."""

pass


def create_thread_local(
thread_safe_session_enabled: bool,
) -> Union[threading.local, DummyThreadLocal]:
if thread_safe_session_enabled:
return threading.local()
return DummyThreadLocal()


def create_rlock(
thread_safe_session_enabled: bool,
) -> Union[threading.RLock, DummyRLock]:
if thread_safe_session_enabled:
return threading.RLock()
return DummyRLock()


warning_dict: Dict[str, WarningHelper] = {}


Expand Down
8 changes: 2 additions & 6 deletions src/snowflake/snowpark/mock/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import json
import logging
import threading
import uuid
from copy import copy
from decimal import Decimal
Expand All @@ -30,7 +31,6 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.server_connection import DEFAULT_STRING_SIZE
from snowflake.snowpark._internal.utils import (
create_rlock,
is_in_stored_procedure,
result_set_to_rows,
)
Expand Down Expand Up @@ -297,11 +297,7 @@ def __init__(self, options: Optional[Dict[str, Any]] = None) -> None:
self._cursor = Mock()
self._options = options or {}
session_params = self._options.get("session_parameters", {})
# thread safe param protection
self._thread_safe_session_enabled = session_params.get(
"PYTHON_SNOWPARK_ENABLE_THREAD_SAFE_SESSION", False
)
self._lock = create_rlock(self._thread_safe_session_enabled)
self._lock = threading.RLock()
self._lower_case_parameters = {}
self._query_listeners = set()
self._telemetry_client = Mock()
Expand Down
Loading

0 comments on commit 5b40dda

Please sign in to comment.