[SPARK-47807][PYTHON][ML] Make pyspark.ml compatible with pyspark-connect

### What changes were proposed in this pull request?

This PR proposes to make `pyspark.ml` compatible with `pyspark-connect`.
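
Concretely, module-scope imports of `SparkContext` (and, transitively, py4j) are moved into `TYPE_CHECKING` blocks and into the function bodies that actually need them. A condensed sketch of the pattern applied throughout the diff below (the function name is illustrative, not from the diff):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime, so
    # `import pyspark.ml` no longer pulls in the classic core.
    from pyspark.core.context import SparkContext


def _call_classic_backend() -> None:
    # Deferred to call time: only classic-backend code paths pay
    # the cost (and the py4j requirement) of this import.
    from pyspark.core.context import SparkContext

    sc = SparkContext._active_spark_context
    assert sc is not None
```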

### Why are the changes needed?

In order for `pyspark-connect` to work without classic PySpark packages and dependencies.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Yes, at apache#45941. Once CI is set up there, it will be tested properly.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45995 from HyukjinKwon/SPARK-47807.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed Apr 11, 2024
1 parent c303b04 commit 6fdf9c9
Showing 11 changed files with 111 additions and 35 deletions.
11 changes: 8 additions & 3 deletions python/pyspark/ml/classification.py
@@ -37,7 +37,7 @@
TYPE_CHECKING,
)

from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
from pyspark import keyword_only, since, inheritable_thread_target
from pyspark.ml import Estimator, Predictor, PredictionModel, Model
from pyspark.ml.param.shared import (
HasRawPredictionCol,
@@ -97,6 +97,7 @@
if TYPE_CHECKING:
from pyspark.ml._typing import P, ParamMap
from py4j.java_gateway import JavaObject
from pyspark.core.context import SparkContext


T = TypeVar("T")
@@ -3677,7 +3678,7 @@ class _OneVsRestSharedReadWrite:
@staticmethod
def saveImpl(
instance: Union[OneVsRest, "OneVsRestModel"],
sc: SparkContext,
sc: "SparkContext",
path: str,
extraMetadata: Optional[Dict[str, Any]] = None,
) -> None:
@@ -3690,7 +3691,7 @@ def saveImpl(
cast(MLWritable, instance.getClassifier()).save(classifierPath)

@staticmethod
def loadClassifier(path: str, sc: SparkContext) -> Union[OneVsRest, "OneVsRestModel"]:
def loadClassifier(path: str, sc: "SparkContext") -> Union[OneVsRest, "OneVsRestModel"]:
classifierPath = os.path.join(path, "classifier")
return DefaultParamsReader.loadParamsInstance(classifierPath, sc)

@@ -3771,6 +3772,8 @@ def setRawPredictionCol(self, value: str) -> "OneVsRestModel":

def __init__(self, models: List[ClassificationModel]):
super(OneVsRestModel, self).__init__()
from pyspark.core.context import SparkContext

self.models = models
if not isinstance(models[0], JavaMLWritable):
return
@@ -3913,6 +3916,8 @@ def _to_java(self) -> "JavaObject":
py4j.java_gateway.JavaObject
Java object equivalent to this instance.
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._gateway is not None

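
In the hunks above, annotations also switch from `sc: SparkContext` to the quoted `sc: "SparkContext"`. The quoted form is a forward reference that stays a plain string at runtime, which is what allows the import to live under `TYPE_CHECKING`. A minimal self-contained illustration (function name and body are hypothetical):

```python
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    from pyspark.core.context import SparkContext  # type checkers only


def load_metadata(sc: "SparkContext", path: str) -> List[str]:
    # At runtime the annotation is just a string, so this module
    # imports cleanly even when pyspark.core is unavailable.
    return sc.textFile(path).collect()
```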
3 changes: 2 additions & 1 deletion python/pyspark/ml/clustering.py
@@ -44,7 +44,6 @@
JavaMLReadable,
GeneralJavaMLWritable,
HasTrainingSummary,
SparkContext,
)
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
from pyspark.ml.common import inherit_doc, _java2py
@@ -226,6 +225,8 @@ def gaussians(self) -> List[MultivariateGaussian]:
Array of :py:class:`MultivariateGaussian` where gaussians[i] represents
the Multivariate Gaussian (Normal) Distribution for Gaussian i
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and self._java_obj is not None

45 changes: 31 additions & 14 deletions python/pyspark/ml/common.py
@@ -17,21 +17,25 @@

from typing import Any, Callable, TYPE_CHECKING

import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray, JavaList

import pyspark.core.context
from pyspark import RDD, SparkContext
from pyspark.util import is_remote_only
from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer
from pyspark.sql import DataFrame, SparkSession

if TYPE_CHECKING:
import py4j.protocol
from py4j.java_gateway import JavaObject

import pyspark.core.context
from pyspark.core.rdd import RDD
from pyspark.core.context import SparkContext
from pyspark.ml._typing import C, JavaObjectOrPickleDump

# Hack for support float('inf') in Py4j
_old_smart_decode = py4j.protocol.smart_decode

if not is_remote_only():
import py4j

# Hack for support float('inf') in Py4j
_old_smart_decode = py4j.protocol.smart_decode

_float_str_mapping = {
"nan": "NaN",
@@ -47,7 +51,10 @@ def _new_smart_decode(obj: Any) -> str:
return _old_smart_decode(obj)


py4j.protocol.smart_decode = _new_smart_decode
if not is_remote_only():
import py4j

py4j.protocol.smart_decode = _new_smart_decode


_picklable_classes = [
@@ -59,7 +66,7 @@ def _new_smart_decode(obj: Any) -> str:


# this will call the ML version of pythonToJava()
def _to_java_object_rdd(rdd: RDD) -> JavaObject:
def _to_java_object_rdd(rdd: "RDD") -> "JavaObject":
"""Return an JavaRDD of Object by unpickling
It will convert each Python object into Java object by Pickle, whenever the
@@ -70,8 +77,12 @@ def _to_java_object_rdd(rdd: RDD) -> JavaObject:
return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)


def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
def _py2java(sc: "SparkContext", obj: Any) -> "JavaObject":
"""Convert Python object into Java"""
from py4j.java_gateway import JavaObject
from pyspark.core.rdd import RDD
from pyspark.core.context import SparkContext

if isinstance(obj, RDD):
obj = _to_java_object_rdd(obj)
elif isinstance(obj, DataFrame):
@@ -91,7 +102,11 @@ def _py2java(sc: SparkContext, obj: Any) -> JavaObject:
return obj


def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
def _java2py(sc: "SparkContext", r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray, JavaList

if isinstance(r, JavaObject):
clsName = r.getClass().getSimpleName()
# convert RDD into JavaRDD
@@ -122,7 +137,9 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any:


def callJavaFunc(
sc: pyspark.core.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any
sc: "pyspark.core.context.SparkContext",
func: Callable[..., "JavaObjectOrPickleDump"],
*args: Any,
) -> "JavaObjectOrPickleDump":
"""Call Java Function"""
java_args = [_py2java(sc, a) for a in args]
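
The module-level py4j `smart_decode` hack above follows the same rule: side effects that require py4j now run only when `pyspark.util.is_remote_only()` is false. Condensed into a single gate (the diff splits it across two), the shape is roughly:

```python
from typing import Any

from pyspark.util import is_remote_only

if not is_remote_only():
    # Patch py4j only when a classic JVM gateway can exist; under a
    # pure pyspark-connect install, py4j may not be importable at all.
    import py4j.protocol

    _old_smart_decode = py4j.protocol.smart_decode
    _float_str_mapping = {"nan": "NaN", "inf": "Infinity", "-inf": "-Infinity"}

    def _new_smart_decode(obj: Any) -> str:
        # Map Python's inf/nan spellings to their Java counterparts
        # before py4j string-encodes them.
        if isinstance(obj, float) and str(obj) in _float_str_mapping:
            return _float_str_mapping[str(obj)]
        return _old_smart_decode(obj)

    py4j.protocol.smart_decode = _new_smart_decode
```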
8 changes: 7 additions & 1 deletion python/pyspark/ml/feature.py
@@ -28,7 +28,7 @@
TYPE_CHECKING,
)

from pyspark import keyword_only, since, SparkContext
from pyspark import keyword_only, since
from pyspark.ml.linalg import _convert_to_vector, DenseMatrix, DenseVector, Vector
from pyspark.sql.dataframe import DataFrame
from pyspark.ml.param.shared import (
@@ -1202,6 +1202,8 @@ def from_vocabulary(
Construct the model directly from a vocabulary list of strings,
requires an active SparkContext.
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._gateway is not None
java_class = sc._gateway.jvm.java.lang.String
@@ -4791,6 +4793,8 @@ def from_labels(
Construct the model directly from an array of label strings,
requires an active SparkContext.
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._gateway is not None
java_class = sc._gateway.jvm.java.lang.String
Expand Down Expand Up @@ -4818,6 +4822,8 @@ def from_arrays_of_labels(
Construct the model directly from an array of array of label strings,
requires an active SparkContext.
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._gateway is not None
java_class = sc._gateway.jvm.java.lang.String
5 changes: 4 additions & 1 deletion python/pyspark/ml/functions.py
@@ -27,7 +27,6 @@
except ImportError:
pass # Let it throw a better error message later when the API is invoked.

from pyspark import SparkContext
from pyspark.sql.functions import pandas_udf
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.types import (
@@ -116,6 +115,8 @@ def vector_to_array(col: Column, dtype: str = "float64") -> Column:
[StructField('vec', ArrayType(FloatType(), False), False),
StructField('oldVec', ArrayType(FloatType(), False), False)]
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._jvm is not None
return Column(
@@ -157,6 +158,8 @@ def array_to_vector(col: Column) -> Column:
>>> df3.select(array_to_vector('v1').alias('vec1')).collect()
[Row(vec1=DenseVector([1.0, 3.0]))]
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None and sc._jvm is not None
return Column(sc._jvm.org.apache.spark.ml.functions.array_to_vector(_to_java_column(col)))
6 changes: 5 additions & 1 deletion python/pyspark/ml/image.py
@@ -29,7 +29,6 @@

import numpy as np

from pyspark import SparkContext
from pyspark.sql.types import Row, StructType, _create_row, _parse_datatype_json_string
from pyspark.sql import SparkSession

@@ -63,6 +62,7 @@ def imageSchema(self) -> StructType:
.. versionadded:: 2.3.0
"""
from pyspark.core.context import SparkContext

if self._imageSchema is None:
ctx = SparkContext._active_spark_context
@@ -83,6 +83,7 @@ def ocvTypes(self) -> Dict[str, int]:
.. versionadded:: 2.3.0
"""
from pyspark.core.context import SparkContext

if self._ocvTypes is None:
ctx = SparkContext._active_spark_context
@@ -103,6 +104,7 @@ def columnSchema(self) -> StructType:
.. versionadded:: 2.4.0
"""
from pyspark.core.context import SparkContext

if self._columnSchema is None:
ctx = SparkContext._active_spark_context
@@ -123,6 +125,7 @@ def imageFields(self) -> List[str]:
.. versionadded:: 2.3.0
"""
from pyspark.core.context import SparkContext

if self._imageFields is None:
ctx = SparkContext._active_spark_context
@@ -137,6 +140,7 @@ def undefinedImageType(self) -> str:
.. versionadded:: 2.3.0
"""
from pyspark.core.context import SparkContext

if self._undefinedImageType is None:
ctx = SparkContext._active_spark_context
9 changes: 6 additions & 3 deletions python/pyspark/ml/pipeline.py
@@ -18,7 +18,7 @@

from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast, TYPE_CHECKING

from pyspark import keyword_only, since, SparkContext
from pyspark import keyword_only, since
from pyspark.ml.base import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import (
@@ -40,6 +40,7 @@
if TYPE_CHECKING:
from pyspark.ml._typing import ParamMap, PipelineStage
from py4j.java_gateway import JavaObject
from pyspark.core.context import SparkContext


@inherit_doc
@@ -200,6 +201,7 @@ def _to_java(self) -> "JavaObject":
py4j.java_gateway.JavaObject
Java object equivalent to this instance.
"""
from pyspark.core.context import SparkContext

gateway = SparkContext._gateway
assert gateway is not None and SparkContext._jvm is not None
@@ -353,6 +355,7 @@ def _to_java(self) -> "JavaObject":
:return: Java object equivalent to this instance.
"""
from pyspark.core.context import SparkContext

gateway = SparkContext._gateway
assert gateway is not None and SparkContext._jvm is not None
@@ -400,7 +403,7 @@ def validateStages(stages: List["PipelineStage"]) -> None:
def saveImpl(
instance: Union[Pipeline, PipelineModel],
stages: List["PipelineStage"],
sc: SparkContext,
sc: "SparkContext",
path: str,
) -> None:
"""
@@ -419,7 +422,7 @@

@staticmethod
def load(
metadata: Dict[str, Any], sc: SparkContext, path: str
metadata: Dict[str, Any], sc: "SparkContext", path: str
) -> Tuple[str, List["PipelineStage"]]:
"""
Load metadata and stages for a :py:class:`Pipeline` or :py:class:`PipelineModel`
10 changes: 9 additions & 1 deletion python/pyspark/ml/stat.py
@@ -18,7 +18,7 @@
import sys
from typing import Optional, Tuple, TYPE_CHECKING

from pyspark import since, SparkContext
from pyspark import since
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.linalg import Matrix, Vector
from pyspark.ml.wrapper import JavaWrapper, _jvm
@@ -102,6 +102,8 @@ def test(
>>> row[0].statistic
4.0
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None

@@ -171,6 +173,8 @@ def corr(dataset: DataFrame, column: str, method: str = "pearson") -> DataFrame:
[ NaN, NaN, 1. , NaN],
[ 0.4 , 0.9486... , NaN, 1. ]])
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None

@@ -239,6 +243,8 @@ def test(dataset: DataFrame, sampleCol: str, distName: str, *params: float) -> DataFrame:
>>> round(ksResult.statistic, 3)
0.175
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None

@@ -424,6 +430,8 @@ def metrics(*metrics: str) -> "SummaryBuilder":
-------
:py:class:`pyspark.ml.stat.SummaryBuilder`
"""
from pyspark.core.context import SparkContext

sc = SparkContext._active_spark_context
assert sc is not None
