diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index 34800710395ce..e9bf1d7840004 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -33,7 +33,8 @@
 from pyspark.ml.torch.distributor import TorchDistributor, _get_gpus_owned
 from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
 from pyspark.sql import SparkSession
-from pyspark.testing.utils import SPARK_HOME, have_torch, torch_requirement_message
+from pyspark.testing.sqlutils import SPARK_HOME
+from pyspark.testing.utils import have_torch, torch_requirement_message
 
 
 @contextlib.contextmanager
diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py
index 0b18212faf605..5c4be6570cd6a 100644
--- a/python/pyspark/sql/avro/functions.py
+++ b/python/pyspark/sql/avro/functions.py
@@ -185,7 +185,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
 def _test() -> None:
     import os
     import sys
-    from pyspark.testing.utils import search_jar
+    from pyspark.testing.sqlutils import search_jar
 
     avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
     if avro_jar is None:
diff --git a/python/pyspark/sql/connect/avro/functions.py b/python/pyspark/sql/connect/avro/functions.py
index b26c29343d883..55067c33dd49b 100644
--- a/python/pyspark/sql/connect/avro/functions.py
+++ b/python/pyspark/sql/connect/avro/functions.py
@@ -94,7 +94,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
 def _test() -> None:
     import os
     import sys
-    from pyspark.testing.utils import search_jar
+    from pyspark.testing.sqlutils import search_jar
 
     avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
     if avro_jar is None:
diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py
index ba43f94ce1eeb..ebe1f70fe8c7d 100644
--- a/python/pyspark/sql/connect/protobuf/functions.py
+++ b/python/pyspark/sql/connect/protobuf/functions.py
@@ -120,7 +120,7 @@ def _read_descriptor_set_file(filePath: str) -> bytes:
 def _test() -> None:
     import os
     import sys
-    from pyspark.testing.utils import search_jar
+    from pyspark.testing.sqlutils import search_jar
 
     protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
     if protobuf_jar is None:
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index ece450a77f4f3..7255ffb9b14b2 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -295,7 +295,7 @@ def _read_descriptor_set_file(filePath: str) -> bytes:
 def _test() -> None:
     import os
     import sys
-    from pyspark.testing.utils import search_jar
+    from pyspark.testing.sqlutils import search_jar
 
     protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
     if protobuf_jar is None:
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 0857591c306ae..4d54a502a536d 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -23,7 +23,7 @@
 from pyspark.util import is_remote_only
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
-from pyspark.testing.utils import SPARK_HOME
+from pyspark.testing.sqlutils import SPARK_HOME
 from pyspark.sql.functions import udf, assert_true, lit
 
 if should_test_connect:
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index e85877cc87e09..706b8c0a8be81 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -44,12 +44,12 @@
 from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
+    SPARK_HOME,
     have_pyarrow,
     have_pandas,
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import SPARK_HOME
 
 
 class DataFrameTestsMixin:
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index a636b852a1e50..4d45c1c10a7de 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -36,8 +36,7 @@
     pyarrow_requirement_message,
 )
 from pyspark.testing import assertDataFrameEqual
-from pyspark.testing.sqlutils import ReusedSQLTestCase
-from pyspark.testing.utils import SPARK_HOME
+from pyspark.testing.sqlutils import ReusedSQLTestCase, SPARK_HOME
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index e5464257422ae..4151dfd90459f 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import glob
 import datetime
 import math
 import os
@@ -32,6 +33,35 @@
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.find_spark_home import _find_spark_home
+
+
+SPARK_HOME = _find_spark_home()
+
+
+def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
+    # Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
+    # vary for SBT or Maven specifically. See also SPARK-26856
+    project_full_path = os.path.join(SPARK_HOME, project_relative_path)
+
+    # We should ignore the following jars
+    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")
+
+    # Search jar in the project dir using the jar name_prefix for both sbt build and maven
+    # build because the artifact jars are in different directories.
+    sbt_build = glob.glob(
+        os.path.join(project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix)
+    )
+    maven_build = glob.glob(os.path.join(project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
+    jar_paths = sbt_build + maven_build
+    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]
+
+    if not jars:
+        return None
+    elif len(jars) > 1:
+        raise RuntimeError("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
+    else:
+        return jars[0]
 
 
 test_not_compiled_message = None
diff --git a/python/pyspark/testing/streamingutils.py b/python/pyspark/testing/streamingutils.py
index dba60b50fcc7b..61f0400513613 100644
--- a/python/pyspark/testing/streamingutils.py
+++ b/python/pyspark/testing/streamingutils.py
@@ -21,7 +21,7 @@
 
 from pyspark import SparkConf, SparkContext, RDD
 from pyspark.streaming import StreamingContext
-from pyspark.testing.utils import search_jar
+from pyspark.testing.sqlutils import search_jar
 
 
 # Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 76f5b48ff9bb0..e4b2f891f34b5 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -15,14 +15,12 @@
 # limitations under the License.
 #
 
-import glob
 import os
 import struct
 import sys
 import unittest
 import difflib
 import functools
-import math
 from decimal import Decimal
 from time import time, sleep
 from typing import (
@@ -37,9 +35,7 @@
 
 from pyspark import SparkConf
 from pyspark.errors import PySparkAssertionError, PySparkException, PySparkTypeError
-from pyspark.errors.exceptions.captured import CapturedException
 from pyspark.errors.exceptions.base import QueryContextType
-from pyspark.find_spark_home import _find_spark_home
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql import Row
 from pyspark.sql.types import StructType, StructField, VariantVal
@@ -48,8 +44,6 @@
 
 __all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
 
-SPARK_HOME = _find_spark_home()
-
 
 def have_package(name: str) -> bool:
     import importlib
@@ -259,31 +253,6 @@ def close(self):
         pass
 
 
-def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
-    # Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
-    # vary for SBT or Maven specifically. See also SPARK-26856
-    project_full_path = os.path.join(SPARK_HOME, project_relative_path)
-
-    # We should ignore the following jars
-    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")
-
-    # Search jar in the project dir using the jar name_prefix for both sbt build and maven
-    # build because the artifact jars are in different directories.
-    sbt_build = glob.glob(
-        os.path.join(project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix)
-    )
-    maven_build = glob.glob(os.path.join(project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
-    jar_paths = sbt_build + maven_build
-    jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]
-
-    if not jars:
-        return None
-    elif len(jars) > 1:
-        raise RuntimeError("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
-    else:
-        return jars[0]
-
-
 def _terminal_color_support():
     try:
         # determine if environment supports color
diff --git a/python/pyspark/tests/test_appsubmit.py b/python/pyspark/tests/test_appsubmit.py
index 79b6b4fa91a75..5f2c8b49d279d 100644
--- a/python/pyspark/tests/test_appsubmit.py
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -23,7 +23,7 @@
 import unittest
 import zipfile
 
-from pyspark.testing.utils import SPARK_HOME
+from pyspark.testing.sqlutils import SPARK_HOME
 
 
 class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 0a9628977af97..c1cd361b7f999 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -24,7 +24,8 @@
 from collections import namedtuple
 
 from pyspark import SparkConf, SparkFiles, SparkContext
-from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME
+from pyspark.testing.sqlutils import SPARK_HOME
+from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
 
 
 class CheckpointTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index 752b5d5599cab..c2d41959d3002 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -36,8 +36,8 @@
     NoOpSerializer,
 )
 from pyspark.sql import SparkSession
-from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest, have_numpy
-from pyspark.testing.sqlutils import have_pandas
+from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest, have_numpy
+from pyspark.testing.sqlutils import SPARK_HOME, have_pandas
 
 
 global_func = lambda: "Hi"  # noqa: E731
diff --git a/python/pyspark/tests/test_readwrite.py b/python/pyspark/tests/test_readwrite.py
index 73f1025635cbd..7ff5007d4f3f3 100644
--- a/python/pyspark/tests/test_readwrite.py
+++ b/python/pyspark/tests/test_readwrite.py
@@ -19,7 +19,8 @@
 import tempfile
 import unittest
 
-from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME
+from pyspark.testing.sqlutils import SPARK_HOME
+from pyspark.testing.utils import ReusedPySparkTestCase
 
 
 class InputFormatTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index 50acc3bab07e8..083d89eaedc7b 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -23,7 +23,8 @@
 import unittest
 
 from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
-from pyspark.testing.utils import PySparkTestCase, SPARK_HOME, eventually
+from pyspark.testing.sqlutils import SPARK_HOME
+from pyspark.testing.utils import PySparkTestCase, eventually
 
 
 class TaskContextTests(PySparkTestCase):
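
Illustrative note (not part of the patch): after this change, both SPARK_HOME and search_jar are imported from pyspark.testing.sqlutils rather than pyspark.testing.utils. A minimal usage sketch, mirroring the _test() helpers touched above and assuming a local Spark checkout with the avro connector built:

    # Hypothetical example; the import path below is the one introduced by this patch.
    from pyspark.testing.sqlutils import SPARK_HOME, search_jar

    # Looks under $SPARK_HOME/connector/avro/target for an SBT- or Maven-built jar,
    # skipping javadoc/sources/tests jars; returns None when nothing is found and
    # raises RuntimeError if more than one candidate jar is present.
    avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
    if avro_jar is None:
        print("spark-avro jar not found under %s" % SPARK_HOME)
    else:
        print("found %s" % avro_jar)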