From 6fdf9c9df545ed50acbce1ec874625baf03d4d2e Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Thu, 11 Apr 2024 17:10:29 +0900
Subject: [PATCH] [SPARK-47807][PYTHON][ML] Make pyspark.ml compatible with
 pyspark-connect

### What changes were proposed in this pull request?

This PR proposes to make `pyspark.ml` compatible with `pyspark-connect`.

### Why are the changes needed?

In order for `pyspark-connect` to work without the classic PySpark packages and dependencies.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Tested at https://github.com/apache/spark/pull/45941. Once CI is set up there, it will be tested properly.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #45995 from HyukjinKwon/SPARK-47807.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/ml/classification.py | 11 +++++--
 python/pyspark/ml/clustering.py     |  3 +-
 python/pyspark/ml/common.py         | 45 ++++++++++++++++++++---------
 python/pyspark/ml/feature.py        |  8 ++++-
 python/pyspark/ml/functions.py      |  5 +++-
 python/pyspark/ml/image.py          |  6 +++-
 python/pyspark/ml/pipeline.py       |  9 ++++--
 python/pyspark/ml/stat.py           | 10 ++++++-
 python/pyspark/ml/tuning.py         | 17 +++++++++--
 python/pyspark/ml/util.py           | 15 ++++++----
 python/pyspark/ml/wrapper.py        | 17 ++++++++++-
 11 files changed, 111 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 38ccba560236e..1eb42f8029b6c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -37,7 +37,7 @@
     TYPE_CHECKING,
 )

-from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
+from pyspark import keyword_only, since, inheritable_thread_target
 from pyspark.ml import Estimator, Predictor, PredictionModel, Model
 from pyspark.ml.param.shared import (
     HasRawPredictionCol,
@@ -97,6 +97,7 @@
 if TYPE_CHECKING:
     from pyspark.ml._typing import P, ParamMap
     from py4j.java_gateway import JavaObject
+    from pyspark.core.context import SparkContext


 T = TypeVar("T")
@@ -3677,7 +3678,7 @@ class _OneVsRestSharedReadWrite:
     @staticmethod
     def saveImpl(
         instance: Union[OneVsRest, "OneVsRestModel"],
-        sc: SparkContext,
+        sc: "SparkContext",
         path: str,
         extraMetadata: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -3690,7 +3691,7 @@ def saveImpl(
         cast(MLWritable, instance.getClassifier()).save(classifierPath)

     @staticmethod
-    def loadClassifier(path: str, sc: SparkContext) -> Union[OneVsRest, "OneVsRestModel"]:
+    def loadClassifier(path: str, sc: "SparkContext") -> Union[OneVsRest, "OneVsRestModel"]:
         classifierPath = os.path.join(path, "classifier")
         return DefaultParamsReader.loadParamsInstance(classifierPath, sc)

@@ -3771,6 +3772,8 @@ def setRawPredictionCol(self, value: str) -> "OneVsRestModel":

     def __init__(self, models: List[ClassificationModel]):
         super(OneVsRestModel, self).__init__()
+        from pyspark.core.context import SparkContext
+
         self.models = models
         if not isinstance(models[0], JavaMLWritable):
             return
@@ -3913,6 +3916,8 @@ def _to_java(self) -> "JavaObject":
         py4j.java_gateway.JavaObject
             Java object equivalent to this instance.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and sc._gateway is not None

diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7f9e87e612432..952c994c62ca4 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -44,7 +44,6 @@
     JavaMLReadable,
     GeneralJavaMLWritable,
     HasTrainingSummary,
-    SparkContext,
 )
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
 from pyspark.ml.common import inherit_doc, _java2py
@@ -226,6 +225,8 @@ def gaussians(self) -> List[MultivariateGaussian]:
         Array of :py:class:`MultivariateGaussian` where gaussians[i] represents
         the Multivariate Gaussian (Normal) Distribution for Gaussian i
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and self._java_obj is not None

diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
index d9d7c9b9c21cf..1ae15fdf547eb 100644
--- a/python/pyspark/ml/common.py
+++ b/python/pyspark/ml/common.py
@@ -17,21 +17,25 @@

 from typing import Any, Callable, TYPE_CHECKING

-import py4j.protocol
-from py4j.protocol import Py4JJavaError
-from py4j.java_gateway import JavaObject
-from py4j.java_collections import JavaArray, JavaList
-
-import pyspark.core.context
-from pyspark import RDD, SparkContext
+from pyspark.util import is_remote_only
 from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer
 from pyspark.sql import DataFrame, SparkSession

 if TYPE_CHECKING:
+    import py4j.protocol
+    from py4j.java_gateway import JavaObject
+
+    import pyspark.core.context
+    from pyspark.core.rdd import RDD
+    from pyspark.core.context import SparkContext
     from pyspark.ml._typing import C, JavaObjectOrPickleDump

-# Hack for support float('inf') in Py4j
-_old_smart_decode = py4j.protocol.smart_decode
+
+if not is_remote_only():
+    import py4j
+
+    # Hack for support float('inf') in Py4j
+    _old_smart_decode = py4j.protocol.smart_decode

 _float_str_mapping = {
     "nan": "NaN",
@@ -47,7 +51,10 @@ def _new_smart_decode(obj: Any) -> str:
     return _old_smart_decode(obj)


-py4j.protocol.smart_decode = _new_smart_decode
+if not is_remote_only():
+    import py4j
+
+    py4j.protocol.smart_decode = _new_smart_decode


 _picklable_classes = [
@@ -59,7 +66,7 @@ def _new_smart_decode(obj: Any) -> str:


 # this will call the ML version of pythonToJava()
-def _to_java_object_rdd(rdd: RDD) -> JavaObject:
+def _to_java_object_rdd(rdd: "RDD") -> "JavaObject":
     """Return an JavaRDD of Object by unpickling

     It will convert each Python object into Java object by Pickle, whenever the
@@ -70,8 +77,12 @@ def _to_java_object_rdd(rdd: RDD) -> JavaObject:
     return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)


-def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
+def _py2java(sc: "SparkContext", obj: Any) -> "JavaObject":
     """Convert Python object into Java"""
+    from py4j.java_gateway import JavaObject
+    from pyspark.core.rdd import RDD
+    from pyspark.core.context import SparkContext
+
     if isinstance(obj, RDD):
         obj = _to_java_object_rdd(obj)
     elif isinstance(obj, DataFrame):
@@ -91,7 +102,11 @@ def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
     return obj


-def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
+def _java2py(sc: "SparkContext", r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
+    from py4j.protocol import Py4JJavaError
+    from py4j.java_gateway import JavaObject
+    from py4j.java_collections import JavaArray, JavaList
+
     if isinstance(r, JavaObject):
         clsName = r.getClass().getSimpleName()
         # convert RDD into JavaRDD
@@ -122,7 +137,9 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt


 def callJavaFunc(
-    sc: pyspark.core.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any
+    sc: "pyspark.core.context.SparkContext",
+    func: Callable[..., "JavaObjectOrPickleDump"],
+    *args: Any,
 ) -> "JavaObjectOrPickleDump":
     """Call Java Function"""
     java_args = [_py2java(sc, a) for a in args]
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 349b50913d7db..9a392c9dd420f 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -28,7 +28,7 @@
     TYPE_CHECKING,
 )

-from pyspark import keyword_only, since, SparkContext
+from pyspark import keyword_only, since
 from pyspark.ml.linalg import _convert_to_vector, DenseMatrix, DenseVector, Vector
 from pyspark.sql.dataframe import DataFrame
 from pyspark.ml.param.shared import (
@@ -1202,6 +1202,8 @@ def from_vocabulary(
         Construct the model directly from a vocabulary list of strings, requires an active
         SparkContext.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and sc._gateway is not None
         java_class = sc._gateway.jvm.java.lang.String
@@ -4791,6 +4793,8 @@ def from_labels(
         Construct the model directly from an array of label strings, requires an active
         SparkContext.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and sc._gateway is not None
         java_class = sc._gateway.jvm.java.lang.String
@@ -4818,6 +4822,8 @@ def from_arrays_of_labels(
         Construct the model directly from an array of array of label strings, requires an
         active SparkContext.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and sc._gateway is not None
         java_class = sc._gateway.jvm.java.lang.String
diff --git a/python/pyspark/ml/functions.py b/python/pyspark/ml/functions.py
index 55631a818bb9b..466d94ccc8889 100644
--- a/python/pyspark/ml/functions.py
+++ b/python/pyspark/ml/functions.py
@@ -27,7 +27,6 @@
 except ImportError:
     pass  # Let it throw a better error message later when the API is invoked.

-from pyspark import SparkContext
 from pyspark.sql.functions import pandas_udf
 from pyspark.sql.column import Column, _to_java_column
 from pyspark.sql.types import (
@@ -116,6 +115,8 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column:
     [StructField('vec', ArrayType(FloatType(), False), False),
      StructField('oldVec', ArrayType(FloatType(), False), False)]
     """
+    from pyspark.core.context import SparkContext
+
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
     return Column(
@@ -157,6 +158,8 @@ def array_to_vector(col: Column) -> Column:
     >>> df3.select(array_to_vector('v1').alias('vec1')).collect()
     [Row(vec1=DenseVector([1.0, 3.0]))]
     """
+    from pyspark.core.context import SparkContext
+
     sc = SparkContext._active_spark_context
     assert sc is not None and sc._jvm is not None
     return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
index 329a56459e6c8..d0223739ffdf8 100644
--- a/python/pyspark/ml/image.py
+++ b/python/pyspark/ml/image.py
@@ -29,7 +29,6 @@

 import numpy as np

-from pyspark import SparkContext
 from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_json_string
 from pyspark.sql import SparkSession

@@ -63,6 +62,7 @@ def imageSchema(self) -> StructType:

         .. versionadded:: 2.3.0
         """
+        from pyspark.core.context import SparkContext

         if self._imageSchema is None:
             ctx = SparkContext._active_spark_context
@@ -83,6 +83,7 @@ def ocvTypes(self) -> Dict[str, int]:

         .. versionadded:: 2.3.0
         """
+        from pyspark.core.context import SparkContext

         if self._ocvTypes is None:
             ctx = SparkContext._active_spark_context
@@ -103,6 +104,7 @@ def columnSchema(self) -> StructType:

         .. versionadded:: 2.4.0
         """
+        from pyspark.core.context import SparkContext

         if self._columnSchema is None:
             ctx = SparkContext._active_spark_context
@@ -123,6 +125,7 @@ def imageFields(self) -> List[str]:

         .. versionadded:: 2.3.0
         """
+        from pyspark.core.context import SparkContext

         if self._imageFields is None:
             ctx = SparkContext._active_spark_context
@@ -137,6 +140,7 @@ def undefinedImageType(self) -> str:

         .. versionadded:: 2.3.0
         """
+        from pyspark.core.context import SparkContext

         if self._undefinedImageType is None:
             ctx = SparkContext._active_spark_context
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 24653d1d919ee..c8415f89670b7 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -18,7 +18,7 @@

 from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, TYPE_CHECKING

-from pyspark import keyword_only, since, SparkContext
+from pyspark import keyword_only, since
 from pyspark.ml.base import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
 from pyspark.ml.util import (
@@ -40,6 +40,7 @@
 if TYPE_CHECKING:
     from pyspark.ml._typing import ParamMap, PipelineStage
     from py4j.java_gateway import JavaObject
+    from pyspark.core.context import SparkContext


 @inherit_doc
@@ -200,6 +201,7 @@ def _to_java(self) -> "JavaObject":
         py4j.java_gateway.JavaObject
             Java object equivalent to this instance.
         """
+        from pyspark.core.context import SparkContext

         gateway = SparkContext._gateway
         assert gateway is not None and SparkContext._jvm is not None
@@ -353,6 +355,7 @@ def _to_java(self) -> "JavaObject":

         :return: Java object equivalent to this instance.
         """
+        from pyspark.core.context import SparkContext

         gateway = SparkContext._gateway
         assert gateway is not None and SparkContext._jvm is not None
@@ -400,7 +403,7 @@ def validateStages(stages: List["PipelineStage"]) -> None:
     def saveImpl(
         instance: Union[Pipeline, PipelineModel],
         stages: List["PipelineStage"],
-        sc: SparkContext,
+        sc: "SparkContext",
         path: str,
     ) -> None:
         """
@@ -419,7 +422,7 @@ def saveImpl(

     @staticmethod
     def load(
-        metadata: Dict[str, Any], sc: SparkContext, path: str
+        metadata: Dict[str, Any], sc: "SparkContext", path: str
     ) -> Tuple[str, List["PipelineStage"]]:
         """
         Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 3ac77b4098219..ec5da94079ea3 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -18,7 +18,7 @@
 import sys
 from typing import Optional, Tuple, TYPE_CHECKING

-from pyspark import since, SparkContext
+from pyspark import since
 from pyspark.ml.common import _java2py, _py2java
 from pyspark.ml.linalg import Matrix, Vector
 from pyspark.ml.wrapper import JavaWrapper, _jvm
@@ -102,6 +102,8 @@ def test(
         >>> row[0].statistic
         4.0
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None

@@ -171,6 +173,8 @@ def corr(dataset: DataFrame, column: str, method: str = "pearson") -> DataFrame:
                [ NaN,  NaN,  1. ,  NaN],
                [ 0.4 ,  0.9486... ,  NaN,  1. ]])
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None

@@ -239,6 +243,8 @@ def test(dataset: DataFrame, sampleCol: str, distName: str, *params: float) -> D
         >>> round(ksResult.statistic, 3)
         0.175
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None

@@ -424,6 +430,8 @@ def metrics(*metrics: str) -> "SummaryBuilder":
         -------
         :py:class:`pyspark.ml.stat.SummaryBuilder`
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None

diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index ae028b2f39969..e8713d81c4d62 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -37,7 +37,7 @@

 import numpy as np

-from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
+from pyspark import keyword_only, since, inheritable_thread_target
 from pyspark.ml import Estimator, Transformer, Model
 from pyspark.ml.common import inherit_doc, _py2java, _java2py
 from pyspark.ml.evaluation import Evaluator, JavaEvaluator
@@ -63,6 +63,7 @@
     from pyspark.ml._typing import ParamMap
     from py4j.java_gateway import JavaObject
     from py4j.java_collections import JavaArray
+    from pyspark.core.context import SparkContext

 __all__ = [
     "ParamGridBuilder",
@@ -272,6 +273,7 @@ def _to_java_impl(self) -> Tuple["JavaObject", "JavaObject", "JavaObject"]:
         """
         Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
         """
+        from pyspark.core.context import SparkContext

         gateway = SparkContext._gateway
         assert gateway is not None and SparkContext._jvm is not None
@@ -301,6 +303,8 @@ class _ValidatorSharedReadWrite:
     def meta_estimator_transfer_param_maps_to_java(
         pyEstimator: Estimator, pyParamMaps: Sequence["ParamMap"]
     ) -> "JavaArray":
+        from pyspark.core.context import SparkContext
+
         pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
         stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages))
         sc = SparkContext._active_spark_context
@@ -335,6 +339,8 @@ def meta_estimator_transfer_param_maps_to_java(
     def meta_estimator_transfer_param_maps_from_java(
         pyEstimator: Estimator, javaParamMaps: "JavaArray"
     ) -> List["ParamMap"]:
+        from pyspark.core.context import SparkContext
+
         pyStages = MetaAlgorithmReadWrite.getAllNestedStages(pyEstimator)
         stagePairs = list(map(lambda stage: (stage, cast(JavaParams, stage)._to_java()), pyStages))
         sc = SparkContext._active_spark_context
@@ -380,7 +386,7 @@ def is_java_convertible(instance: _ValidatorParams) -> bool:
     def saveImpl(
         path: str,
         instance: _ValidatorParams,
-        sc: SparkContext,
+        sc: "SparkContext",
         extraMetadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         numParamsNotJson = 0
@@ -424,7 +430,7 @@ def saveImpl(

     @staticmethod
     def load(
-        path: str, sc: SparkContext, metadata: Dict[str, Any]
+        path: str, sc: "SparkContext", metadata: Dict[str, Any]
     ) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
         evaluatorPath = os.path.join(path, "evaluator")
         evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
@@ -1089,6 +1095,8 @@ def _from_java(cls, java_stage: "JavaObject") -> "CrossValidatorModel":
         Given a Java CrossValidatorModel, create and return a Python wrapper of it.
         Used for ML persistence.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None

@@ -1126,6 +1134,7 @@ def _to_java(self) -> "JavaObject":
         py4j.java_gateway.JavaObject
             Java object equivalent to this instance.
         """
+        from pyspark.core.context import SparkContext

         sc = SparkContext._active_spark_context
         assert sc is not None
@@ -1648,6 +1657,7 @@ def _from_java(cls, java_stage: "JavaObject") -> "TrainValidationSplitModel":
         Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
         Used for ML persistence.
         """
+        from pyspark.core.context import SparkContext

         # Load information from java_stage to the instance.
         sc = SparkContext._active_spark_context
@@ -1687,6 +1697,7 @@ def _to_java(self) -> "JavaObject":
         py4j.java_gateway.JavaObject
             Java object equivalent to this instance.
         """
+        from pyspark.core.context import SparkContext

         sc = SparkContext._active_spark_context
         assert sc is not None
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index b6e3ea2a51a67..b9a2829a1ca0b 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -34,7 +34,7 @@
     TYPE_CHECKING,
 )

-from pyspark import SparkContext, since
+from pyspark import since
 from pyspark.ml.common import inherit_doc
 from pyspark.sql import SparkSession
 from pyspark.sql.utils import is_remote
@@ -45,6 +45,7 @@
     from pyspark.ml._typing import PipelineStage
     from pyspark.ml.base import Params
     from pyspark.ml.wrapper import JavaWrapper
+    from pyspark.core.context import SparkContext

 T = TypeVar("T")
 RW = TypeVar("RW", bound="BaseReadWrite")
@@ -61,6 +62,8 @@ def _jvm() -> "JavaGateway":
     Returns the JVM view associated with SparkContext.
     Must be called after SparkContext is initialized.
     """
+    from pyspark.core.context import SparkContext
+
     jvm = SparkContext._jvm
     if jvm:
         return jvm
@@ -119,7 +122,7 @@ def sparkSession(self) -> SparkSession:
         return self._sparkSession

     @property
-    def sc(self) -> SparkContext:
+    def sc(self) -> "SparkContext":
         """
         Returns the underlying `SparkContext`.
         """
@@ -435,7 +438,7 @@ def extractJsonParams(instance: "Params", skipParams: Sequence[str]) -> Dict[str
     def saveMetadata(
         instance: "Params",
         path: str,
-        sc: SparkContext,
+        sc: "SparkContext",
         extraMetadata: Optional[Dict[str, Any]] = None,
         paramMap: Optional[Dict[str, Any]] = None,
     ) -> None:
@@ -466,7 +469,7 @@ def saveMetadata(
     @staticmethod
     def _get_metadata_to_save(
         instance: "Params",
-        sc: SparkContext,
+        sc: "SparkContext",
         extraMetadata: Optional[Dict[str, Any]] = None,
         paramMap: Optional[Dict[str, Any]] = None,
     ) -> str:
@@ -562,7 +565,7 @@ def load(self, path: str) -> RL:
         return instance

     @staticmethod
-    def loadMetadata(path: str, sc: SparkContext, expectedClassName: str = "") -> Dict[str, Any]:
+    def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") -> Dict[str, Any]:
         """
         Load metadata saved using :py:meth:`DefaultParamsWriter.saveMetadata`

@@ -634,7 +637,7 @@ def isPythonParamsInstance(metadata: Dict[str, Any]) -> bool:
         return metadata["class"].startswith("pyspark.ml.")

     @staticmethod
-    def loadParamsInstance(path: str, sc: SparkContext) -> RL:
+    def loadParamsInstance(path: str, sc: "SparkContext") -> RL:
         """
         Load a :py:class:`Params` instance from the given path, and return it.
         This assumes the instance inherits from :py:class:`MLReadable`.
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index ea2a38cd91015..eed7781dc71e3 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -19,7 +19,6 @@
 from typing import Any, Generic, Optional, List, Type, TypeVar, TYPE_CHECKING

 from pyspark import since
-from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml import Estimator, Predictor, PredictionModel, Transformer, Model
 from pyspark.ml.base import _PredictorParams
@@ -49,6 +48,8 @@ def __init__(self, java_obj: Optional["JavaObject"] = None):
         self._java_obj = java_obj

     def __del__(self) -> None:
+        from pyspark.core.context import SparkContext
+
         if SparkContext._active_spark_context and self._java_obj is not None:
             SparkContext._active_spark_context._gateway.detach(  # type: ignore[union-attr]
                 self._java_obj
@@ -63,6 +64,8 @@ def _create_from_java_class(cls: Type[JW], java_class: str, *args: Any) -> JW:
         return cls(java_obj)

     def _call_java(self, name: str, *args: Any) -> Any:
+        from pyspark.core.context import SparkContext
+
         m = getattr(self._java_obj, name)
         sc = SparkContext._active_spark_context
         assert sc is not None
@@ -75,6 +78,8 @@ def _new_java_obj(java_class: str, *args: Any) -> "JavaObject":
         """
         Returns a new Java object.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None

@@ -114,6 +119,8 @@ def _new_java_array(pylist: List[Any], java_class: "JavaClass") -> "JavaObject":
         :py:class:`py4j.java_collections.JavaArray`
             Java Array of converted pylist.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None
         assert sc._gateway is not None
@@ -150,6 +157,8 @@ def _make_java_param_pair(self, param: Param[T], value: T) -> "JavaObject":
         """
         Makes a Java param pair.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and self._java_obj is not None

@@ -162,6 +171,8 @@ def _transfer_params_to_java(self) -> None:
         """
         Transforms the embedded params to the companion Java object.
         """
+        from pyspark.core.context import SparkContext
+
         assert self._java_obj is not None

         pair_defaults = []
@@ -211,6 +222,8 @@ def _transfer_params_from_java(self) -> None:
         """
         Transforms the embedded params from the companion Java object.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None and self._java_obj is not None

@@ -234,6 +247,8 @@ def _transfer_param_map_from_java(self, javaParamMap: "JavaObject") -> "ParamMap
         """
         Transforms a Java ParamMap into a Python ParamMap.
         """
+        from pyspark.core.context import SparkContext
+
         sc = SparkContext._active_spark_context
         assert sc is not None