[SPARK-50867][PYTHON][TESTS] Move search_jar and SPARK_HOME to `sqlutils.py`

### What changes were proposed in this pull request?
Move `search_jar` and `SPARK_HOME` to `sqlutils.py`

### Why are the changes needed?
`search_jar` and `SPARK_HOME` are only needed for tests, but `pyspark.testing.utils` is [exposed](https://apache.github.io/spark/api/python/reference/pyspark.testing.html) to end users.

When `pyspark.testing.utils` is imported, the module-level `SPARK_HOME = _find_spark_home()` assignment may fail and make the import itself fail.
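
A minimal sketch of the import-time behavior this change avoids, mirroring the hunks below (illustrative only, not the full module contents):

```python
# Before this change, the user-facing pyspark/testing/utils.py ran this at import time:
#
#     from pyspark.find_spark_home import _find_spark_home
#     SPARK_HOME = _find_spark_home()   # may fail, so `import pyspark.testing.utils` fails
#
# After this change, the lookup happens only in the test-only module,
# and test code imports the result from there instead:
from pyspark.testing.sqlutils import SPARK_HOME, search_jar
```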

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#49543 from zhengruifeng/py_find_home.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
zhengruifeng committed Jan 17, 2025
1 parent c81fa33 commit 3a12038
Showing 16 changed files with 49 additions and 47 deletions.
3 changes: 2 additions & 1 deletion python/pyspark/ml/torch/tests/test_distributor.py
@@ -33,7 +33,8 @@
from pyspark.ml.torch.distributor import TorchDistributor, _get_gpus_owned
from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
from pyspark.sql import SparkSession
from pyspark.testing.utils import SPARK_HOME, have_torch, torch_requirement_message
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.testing.utils import have_torch, torch_requirement_message


@contextlib.contextmanager
2 changes: 1 addition & 1 deletion python/pyspark/sql/avro/functions.py
@@ -185,7 +185,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
def _test() -> None:
import os
import sys
from pyspark.testing.utils import search_jar
from pyspark.testing.sqlutils import search_jar

avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
if avro_jar is None:
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/avro/functions.py
@@ -94,7 +94,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
def _test() -> None:
import os
import sys
from pyspark.testing.utils import search_jar
from pyspark.testing.sqlutils import search_jar

avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
if avro_jar is None:
2 changes: 1 addition & 1 deletion python/pyspark/sql/connect/protobuf/functions.py
@@ -120,7 +120,7 @@ def _read_descriptor_set_file(filePath: str) -> bytes:
def _test() -> None:
import os
import sys
from pyspark.testing.utils import search_jar
from pyspark.testing.sqlutils import search_jar

protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
if protobuf_jar is None:
2 changes: 1 addition & 1 deletion python/pyspark/sql/protobuf/functions.py
@@ -295,7 +295,7 @@ def _read_descriptor_set_file(filePath: str) -> bytes:
def _test() -> None:
import os
import sys
from pyspark.testing.utils import search_jar
from pyspark.testing.sqlutils import search_jar

protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
if protobuf_jar is None:
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -23,7 +23,7 @@
from pyspark.util import is_remote_only
from pyspark.sql import SparkSession
from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
from pyspark.testing.utils import SPARK_HOME
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.sql.functions import udf, assert_true, lit

if should_test_connect:
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_dataframe.py
@@ -44,12 +44,12 @@
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
SPARK_HOME,
have_pyarrow,
have_pandas,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import SPARK_HOME


class DataFrameTestsMixin:
3 changes: 1 addition & 2 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -36,8 +36,7 @@
pyarrow_requirement_message,
)
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import SPARK_HOME
from pyspark.testing.sqlutils import ReusedSQLTestCase, SPARK_HOME


@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
30 changes: 30 additions & 0 deletions python/pyspark/testing/sqlutils.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import glob
import datetime
import math
import os
@@ -32,6 +33,35 @@
have_pyarrow,
pyarrow_requirement_message,
)
from pyspark.find_spark_home import _find_spark_home


SPARK_HOME = _find_spark_home()


def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
# Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
# vary for SBT or Maven specifically. See also SPARK-26856
project_full_path = os.path.join(SPARK_HOME, project_relative_path)

# We should ignore the following jars
ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

# Search jar in the project dir using the jar name_prefix for both sbt build and maven
# build because the artifact jars are in different directories.
sbt_build = glob.glob(
os.path.join(project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix)
)
maven_build = glob.glob(os.path.join(project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
jar_paths = sbt_build + maven_build
jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

if not jars:
return None
elif len(jars) > 1:
raise RuntimeError("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
else:
return jars[0]


test_not_compiled_message = None
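
The relocated helper is consumed exactly as before: callers pass a project-relative path plus the SBT and Maven artifact name prefixes. A typical call, mirroring the avro `_test()` hunk earlier (a usage sketch, assuming the connector has been built locally):

```python
from pyspark.testing.sqlutils import search_jar

# Searches $SPARK_HOME/connector/avro/target for a jar produced by either the SBT
# build (target/scala-*/spark-avro*.jar) or the Maven build (target/spark-avro*.jar),
# skipping javadoc/sources/tests jars. Returns None if nothing is found and raises
# RuntimeError if more than one candidate jar matches.
avro_jar = search_jar("connector/avro", "spark-avro", "spark-avro")
```
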
2 changes: 1 addition & 1 deletion python/pyspark/testing/streamingutils.py
@@ -21,7 +21,7 @@

from pyspark import SparkConf, SparkContext, RDD
from pyspark.streaming import StreamingContext
from pyspark.testing.utils import search_jar
from pyspark.testing.sqlutils import search_jar


# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py
31 changes: 0 additions & 31 deletions python/pyspark/testing/utils.py
@@ -15,14 +15,12 @@
# limitations under the License.
#

import glob
import os
import struct
import sys
import unittest
import difflib
import functools
import math
from decimal import Decimal
from time import time, sleep
from typing import (
@@ -37,9 +35,7 @@

from pyspark import SparkConf
from pyspark.errors import PySparkAssertionError, PySparkException, PySparkTypeError
from pyspark.errors.exceptions.captured import CapturedException
from pyspark.errors.exceptions.base import QueryContextType
from pyspark.find_spark_home import _find_spark_home
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, VariantVal
@@ -48,8 +44,6 @@

__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]

SPARK_HOME = _find_spark_home()


def have_package(name: str) -> bool:
import importlib
@@ -259,31 +253,6 @@ def close(self):
pass


def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
# Note that 'sbt_jar_name_prefix' and 'mvn_jar_name_prefix' are used since the prefix can
# vary for SBT or Maven specifically. See also SPARK-26856
project_full_path = os.path.join(SPARK_HOME, project_relative_path)

# We should ignore the following jars
ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

# Search jar in the project dir using the jar name_prefix for both sbt build and maven
# build because the artifact jars are in different directories.
sbt_build = glob.glob(
os.path.join(project_full_path, "target/scala-*/%s*.jar" % sbt_jar_name_prefix)
)
maven_build = glob.glob(os.path.join(project_full_path, "target/%s*.jar" % mvn_jar_name_prefix))
jar_paths = sbt_build + maven_build
jars = [jar for jar in jar_paths if not jar.endswith(ignored_jar_suffixes)]

if not jars:
return None
elif len(jars) > 1:
raise RuntimeError("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
else:
return jars[0]


def _terminal_color_support():
try:
# determine if environment supports color
2 changes: 1 addition & 1 deletion python/pyspark/tests/test_appsubmit.py
@@ -23,7 +23,7 @@
import unittest
import zipfile

from pyspark.testing.utils import SPARK_HOME
from pyspark.testing.sqlutils import SPARK_HOME


class SparkSubmitTests(unittest.TestCase):
3 changes: 2 additions & 1 deletion python/pyspark/tests/test_context.py
@@ -24,7 +24,8 @@
from collections import namedtuple

from pyspark import SparkConf, SparkFiles, SparkContext
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest, SPARK_HOME
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest


class CheckpointTests(ReusedPySparkTestCase):
4 changes: 2 additions & 2 deletions python/pyspark/tests/test_rdd.py
@@ -36,8 +36,8 @@
NoOpSerializer,
)
from pyspark.sql import SparkSession
from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest, have_numpy
from pyspark.testing.sqlutils import have_pandas
from pyspark.testing.utils import ReusedPySparkTestCase, QuietTest, have_numpy
from pyspark.testing.sqlutils import SPARK_HOME, have_pandas


global_func = lambda: "Hi" # noqa: E731
3 changes: 2 additions & 1 deletion python/pyspark/tests/test_readwrite.py
@@ -19,7 +19,8 @@
import tempfile
import unittest

from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.testing.utils import ReusedPySparkTestCase


class InputFormatTests(ReusedPySparkTestCase):
3 changes: 2 additions & 1 deletion python/pyspark/tests/test_taskcontext.py
@@ -23,7 +23,8 @@
import unittest

from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
from pyspark.testing.utils import PySparkTestCase, SPARK_HOME, eventually
from pyspark.testing.sqlutils import SPARK_HOME
from pyspark.testing.utils import PySparkTestCase, eventually


class TaskContextTests(PySparkTestCase):
