This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 762599cc64df [SPARK-50940][ML][PYTHON][CONNECT] Adds support for CrossValidator/CrossValidatorModel on connect
762599cc64df is described below

commit 762599cc64df36ebff2b794dd6c77f48c428749f
Author: Bobby Wang <[email protected]>
AuthorDate: Sun Jan 26 14:41:15 2025 +0800

    [SPARK-50940][ML][PYTHON][CONNECT] Adds support for CrossValidator/CrossValidatorModel on connect
    
    ### What changes were proposed in this pull request?
    
    Add support for CrossValidator/CrossValidatorModel on Spark Connect.
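    
    For example, the usual tuning workflow should now also run against a
    Spark Connect session (a minimal sketch distilled from the new tests;
    train_df stands in for any labeled training DataFrame):
    
        from pyspark.ml.classification import LogisticRegression
        from pyspark.ml.evaluation import BinaryClassificationEvaluator
        from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder
    
        lr = LogisticRegression()
        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
        cv = CrossValidator(
            estimator=lr, estimatorParamMaps=grid, evaluator=BinaryClassificationEvaluator()
        )
        model = cv.fit(train_df)
    
        # save & load now round-trip through the Connect client as well
        model.write().save("/tmp/cv_model")
        loaded = CrossValidatorModel.load("/tmp/cv_model")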
    
    ### Why are the changes needed?
    
    For feature parity between Spark Connect and classic Spark ML.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    The newly added tests pass.
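    They can also be run locally, e.g.:
    
        python/run-tests --testnames pyspark.ml.tests.test_tuning
        python/run-tests --testnames pyspark.ml.tests.connect.test_parity_tuning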
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49644 from wbo4958/cv.
    
    Authored-by: Bobby Wang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |  2 +
 python/pyspark/ml/connect/readwrite.py             | 48 ++++++++++-
 .../pyspark/ml/tests/connect/test_parity_tuning.py | 56 +++++++++++++
 python/pyspark/ml/tests/test_tuning.py             | 98 ++++++++++++++++++++++
 python/pyspark/ml/tuning.py                        | 32 ++++---
 python/pyspark/ml/util.py                          |  3 +-
 6 files changed, 224 insertions(+), 15 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f76427f79761..8e518b0febe3 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -675,6 +675,7 @@ pyspark_ml = Module(
         "pyspark.ml.tests.test_param",
         "pyspark.ml.tests.test_persistence",
         "pyspark.ml.tests.test_pipeline",
+        "pyspark.ml.tests.test_tuning",
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.tuning.test_tuning",
@@ -1127,6 +1128,7 @@ pyspark_ml_connect = Module(
         "pyspark.ml.tests.connect.test_parity_evaluation",
         "pyspark.ml.tests.connect.test_parity_feature",
         "pyspark.ml.tests.connect.test_parity_pipeline",
+        "pyspark.ml.tests.connect.test_parity_tuning",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, 
pandas, and pyarrow and
diff --git a/python/pyspark/ml/connect/readwrite.py b/python/pyspark/ml/connect/readwrite.py
index eea966270871..3bf2031538d9 100644
--- a/python/pyspark/ml/connect/readwrite.py
+++ b/python/pyspark/ml/connect/readwrite.py
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any
+from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
+from pyspark.ml.tuning import CrossValidatorModelWriter, CrossValidatorModel
 from pyspark.ml.util import MLWriter, MLReader, RL
 from pyspark.ml.wrapper import JavaWrapper
 
@@ -29,6 +29,19 @@ if TYPE_CHECKING:
     from pyspark.ml.util import JavaMLReadable, JavaMLWritable
 
 
+class RemoteCrossValidatorModelWriter(CrossValidatorModelWriter):
+    def __init__(
+        self,
+        instance: "CrossValidatorModel",
+        optionMap: Dict[str, Any] = {},
+        session: Optional["SparkSession"] = None,
+    ):
+        super(RemoteCrossValidatorModelWriter, self).__init__(instance)
+        self.instance = instance
+        self.optionMap = optionMap
+        self.session(session)  # type: ignore[arg-type]
+
+
 class RemoteMLWriter(MLWriter):
     def __init__(self, instance: "JavaMLWritable") -> None:
         super().__init__()
@@ -63,6 +76,7 @@ class RemoteMLWriter(MLWriter):
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
+        from pyspark.ml.tuning import CrossValidator
 
         # Spark Connect ML is built on scala Spark.ML, that means we're only
         # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -126,6 +140,21 @@ class RemoteMLWriter(MLWriter):
                 path,
             )
 
+        elif isinstance(instance, CrossValidator):
+            from pyspark.ml.tuning import CrossValidatorWriter
+
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for CrossValidator")
+            cv_writer = CrossValidatorWriter(instance)
+            cv_writer.session(session)  # type: ignore[arg-type]
+            cv_writer.save(path)
+        elif isinstance(instance, CrossValidatorModel):
+            if shouldOverwrite:
+                # TODO(SPARK-50954): Support client side model path overwrite
+                warnings.warn("Overwrite doesn't take effect for CrossValidatorModel")
+            cvm_writer = RemoteCrossValidatorModelWriter(instance, optionMap, session)
+            cvm_writer.save(path)
         else:
             raise NotImplementedError(f"Unsupported write for {instance.__class__}")
 
@@ -153,6 +182,7 @@ class RemoteMLReader(MLReader[RL]):
         from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
         from pyspark.ml.evaluation import JavaEvaluator
         from pyspark.ml.pipeline import Pipeline, PipelineModel
+        from pyspark.ml.tuning import CrossValidator
 
         if (
             issubclass(clazz, JavaModel)
@@ -217,5 +247,19 @@ class RemoteMLReader(MLReader[RL]):
             else:
                 return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
 
+        elif issubclass(clazz, CrossValidator):
+            from pyspark.ml.tuning import CrossValidatorReader
+
+            cv_reader = CrossValidatorReader(CrossValidator)
+            cv_reader.session(session)
+            return cv_reader.load(path)
+
+        elif issubclass(clazz, CrossValidatorModel):
+            from pyspark.ml.tuning import CrossValidatorModelReader
+
+            cvm_reader = CrossValidatorModelReader(CrossValidator)
+            cvm_reader.session(session)
+            return cvm_reader.load(path)
+
         else:
             raise RuntimeError(f"Unsupported read for {clazz}")
diff --git a/python/pyspark/ml/tests/connect/test_parity_tuning.py b/python/pyspark/ml/tests/connect/test_parity_tuning.py
new file mode 100644
index 000000000000..473661d6c288
--- /dev/null
+++ b/python/pyspark/ml/tests/connect/test_parity_tuning.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.connect.readwrite import RemoteCrossValidatorModelWriter
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.tests.test_tuning import TuningTestsMixin
+from pyspark.ml.tuning import CrossValidatorModel
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class TuningParityTests(TuningTestsMixin, ReusedConnectTestCase):
+    def test_remote_cross_validator_model_writer(self):
+        df = self.spark.createDataFrame(
+            [
+                (1.0, 1.0, Vectors.dense(0.0, 5.0)),
+                (0.0, 2.0, Vectors.dense(1.0, 2.0)),
+                (1.0, 3.0, Vectors.dense(2.0, 1.0)),
+                (0.0, 4.0, Vectors.dense(3.0, 3.0)),
+            ],
+            ["label", "weight", "features"],
+        )
+
+        lor = LogisticRegression()
+        lor_model = lor.fit(df)
+        cv_model = CrossValidatorModel(lor_model)
+        writer = RemoteCrossValidatorModelWriter(cv_model, {"a": "b"}, self.spark)
+        self.assertEqual(writer.optionMap["a"], "b")
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_tuning import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tests/test_tuning.py b/python/pyspark/ml/tests/test_tuning.py
new file mode 100644
index 000000000000..2bc0e22c1209
--- /dev/null
+++ b/python/pyspark/ml/tests/test_tuning.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class TuningTestsMixin:
+    def test_cross_validator(self):
+        dataset = self.spark.createDataFrame(
+            [
+                (Vectors.dense([0.0]), 0.0),
+                (Vectors.dense([0.4]), 1.0),
+                (Vectors.dense([0.5]), 0.0),
+                (Vectors.dense([0.6]), 1.0),
+                (Vectors.dense([1.0]), 1.0),
+            ]
+            * 10,
+            ["features", "label"],
+        )
+        lr = LogisticRegression()
+        grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+        evaluator = BinaryClassificationEvaluator()
+        cv = CrossValidator(
+            estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=1
+        )
+
+        self.assertEqual(cv.getEstimator(), lr)
+        self.assertEqual(cv.getEvaluator(), evaluator)
+        self.assertEqual(cv.getParallelism(), 1)
+        self.assertEqual(cv.getEstimatorParamMaps(), grid)
+
+        model = cv.fit(dataset)
+        self.assertEqual(model.getEstimator(), lr)
+        self.assertEqual(model.getEvaluator(), evaluator)
+        self.assertEqual(model.getEstimatorParamMaps(), grid)
+        self.assertTrue(np.isclose(model.avgMetrics[0], 0.5, atol=1e-4))
+
+        output = model.transform(dataset)
+        self.assertEqual(
+            output.columns, ["features", "label", "rawPrediction", "probability", "prediction"]
+        )
+        self.assertEqual(output.count(), 50)
+
+        # save & load
+        with tempfile.TemporaryDirectory(prefix="cv_lr") as d:
+            path1 = os.path.join(d, "cv")
+            cv.write().save(path1)
+            cv2 = CrossValidator.load(path1)
+            self.assertEqual(str(cv), str(cv2))
+            self.assertEqual(str(cv.getEstimator()), str(cv2.getEstimator()))
+            self.assertEqual(str(cv.getEvaluator()), str(cv2.getEvaluator()))
+
+            path2 = os.path.join(d, "cv_model")
+            model.write().save(path2)
+            model2 = CrossValidatorModel.load(path2)
+            self.assertEqual(str(model), str(model2))
+            self.assertEqual(str(model.getEstimator()), str(model2.getEstimator()))
+            self.assertEqual(str(model.getEvaluator()), str(model2.getEvaluator()))
+
+
+class TuningTests(TuningTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_tuning import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 695bbf98517c..06d3837e1ae4 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -53,6 +53,8 @@ from pyspark.ml.util import (
     MLWriter,
     JavaMLReader,
     JavaMLWriter,
+    try_remote_write,
+    try_remote_read,
 )
 from pyspark.ml.wrapper import JavaParams, JavaEstimator, JavaWrapper
 from pyspark.sql.functions import col, lit, rand, UserDefinedFunction
@@ -386,7 +388,7 @@ class _ValidatorSharedReadWrite:
     def saveImpl(
         path: str,
         instance: _ValidatorParams,
-        sc: "SparkContext",
+        sc: Union["SparkContext", "SparkSession"],
         extraMetadata: Optional[Dict[str, Any]] = None,
     ) -> None:
         numParamsNotJson = 0
@@ -430,7 +432,7 @@ class _ValidatorSharedReadWrite:
 
     @staticmethod
     def load(
-        path: str, sc: "SparkContext", metadata: Dict[str, Any]
+        path: str, sc: Union["SparkContext", "SparkSession"], metadata: Dict[str, Any]
     ) -> Tuple[Dict[str, Any], Estimator, Evaluator, List["ParamMap"]]:
         evaluatorPath = os.path.join(path, "evaluator")
         evaluator: Evaluator = DefaultParamsReader.loadParamsInstance(evaluatorPath, sc)
@@ -513,12 +515,12 @@ class CrossValidatorReader(MLReader["CrossValidator"]):
         self.cls = cls
 
     def load(self, path: str) -> "CrossValidator":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
             metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
-                path, self.sc, metadata
+                path, self.sparkSession, metadata
             )
             cv = CrossValidator(
                 estimator=estimator, estimatorParamMaps=estimatorParamMaps, evaluator=evaluator
@@ -536,7 +538,7 @@ class CrossValidatorWriter(MLWriter):
 
     def saveImpl(self, path: str) -> None:
         _ValidatorSharedReadWrite.validateParams(self.instance)
-        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sc)
+        _ValidatorSharedReadWrite.saveImpl(path, self.instance, self.sparkSession)
 
 
 @inherit_doc
@@ -546,16 +548,18 @@ class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
         self.cls = cls
 
     def load(self, path: str) -> "CrossValidatorModel":
-        metadata = DefaultParamsReader.loadMetadata(path, self.sc)
+        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
         if not DefaultParamsReader.isPythonParamsInstance(metadata):
             return JavaMLReader(self.cls).load(path)  # type: ignore[arg-type]
         else:
             metadata, estimator, evaluator, estimatorParamMaps = _ValidatorSharedReadWrite.load(
-                path, self.sc, metadata
+                path, self.sparkSession, metadata
             )
             numFolds = metadata["paramMap"]["numFolds"]
             bestModelPath = os.path.join(path, "bestModel")
-            bestModel: Model = DefaultParamsReader.loadParamsInstance(bestModelPath, self.sc)
+            bestModel: Model = DefaultParamsReader.loadParamsInstance(
+                bestModelPath, self.sparkSession
+            )
             avgMetrics = metadata["avgMetrics"]
             if "stdMetrics" in metadata:
                 stdMetrics = metadata["stdMetrics"]
@@ -571,7 +575,7 @@ class CrossValidatorModelReader(MLReader["CrossValidatorModel"]):
                             path, "subModels", f"fold{splitIndex}", 
f"{paramIndex}"
                         )
                         subModels[splitIndex][paramIndex] = DefaultParamsReader.loadParamsInstance(
-                            modelPath, self.sc
+                            modelPath, self.sparkSession
                         )
             else:
                 subModels = None
@@ -608,7 +612,9 @@ class CrossValidatorModelWriter(MLWriter):
         if instance.stdMetrics:
             extraMetadata["stdMetrics"] = instance.stdMetrics
 
-        _ValidatorSharedReadWrite.saveImpl(path, instance, self.sc, extraMetadata=extraMetadata)
+        _ValidatorSharedReadWrite.saveImpl(
+            path, instance, self.sparkSession, extraMetadata=extraMetadata
+        )
         bestModelPath = os.path.join(path, "bestModel")
         cast(MLWritable, instance.bestModel).save(bestModelPath)
         if persistSubModels:
@@ -845,7 +851,7 @@ class CrossValidator(
             train = datasets[i][0].cache()
 
             tasks = map(
-                inheritable_thread_target,
+                inheritable_thread_target(dataset.sparkSession),
                 _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam),
             )
             for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
@@ -939,6 +945,7 @@ class CrossValidator(
         return newCV
 
     @since("2.3.0")
+    @try_remote_write
     def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -947,6 +954,7 @@ class CrossValidator(
 
     @classmethod
     @since("2.3.0")
+    @try_remote_read
     def read(cls) -> CrossValidatorReader:
         """Returns an MLReader instance for this class."""
         return CrossValidatorReader(cls)
@@ -1077,6 +1085,7 @@ class CrossValidatorModel(
         )
 
     @since("2.3.0")
+    @try_remote_write
     def write(self) -> MLWriter:
         """Returns an MLWriter instance for this ML instance."""
         if _ValidatorSharedReadWrite.is_java_convertible(self):
@@ -1085,6 +1094,7 @@ class CrossValidatorModel(
 
     @classmethod
     @since("2.3.0")
+    @try_remote_read
     def read(cls) -> CrossValidatorModelReader:
         """Returns an MLReader instance for this class."""
         return CrossValidatorModelReader(cls)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6006d131b5c0..309f8452ac79 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -149,11 +149,10 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
     def wrapped(self: "JavaWrapper", dataset: "ConnectDataFrame") -> Any:
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
             from pyspark.ml import Model, Transformer
-            from pyspark.sql.connect.session import SparkSession
             from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
             from pyspark.ml.connect.serialize import serialize_ml_params
 
-            session = SparkSession.getActiveSession()
+            session = dataset.sparkSession
             assert session is not None
             # Model is also a Transformer, so we must match Model first
             if isinstance(self, Model):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
